Spaces:
Sleeping
Sleeping
Commit
·
6ec3bf6
1
Parent(s):
590d01a
Finished MNIST downloading and caching modules
Browse files- notebooks/dataloader.ipynb +0 -198
- notebooks/dataset.ipynb +187 -0
- src/dataset.py +57 -0
- {datasets → src}/downloader.py +21 -13
notebooks/dataloader.ipynb
DELETED
|
@@ -1,198 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 1,
|
| 6 |
-
"metadata": {
|
| 7 |
-
"collapsed": true,
|
| 8 |
-
"pycharm": {
|
| 9 |
-
"name": "#%%\n"
|
| 10 |
-
}
|
| 11 |
-
},
|
| 12 |
-
"outputs": [
|
| 13 |
-
{
|
| 14 |
-
"name": "stdout",
|
| 15 |
-
"output_type": "stream",
|
| 16 |
-
"text": [
|
| 17 |
-
"/mnt/c/Users/rzimm/Workspace/data/zero-to-hero\n"
|
| 18 |
-
]
|
| 19 |
-
}
|
| 20 |
-
],
|
| 21 |
-
"source": [
|
| 22 |
-
"%cd ..\n",
|
| 23 |
-
"%load_ext autoreload\n",
|
| 24 |
-
"%autoreload 2\n",
|
| 25 |
-
"from datasets import downloader"
|
| 26 |
-
]
|
| 27 |
-
},
|
| 28 |
-
{
|
| 29 |
-
"cell_type": "code",
|
| 30 |
-
"execution_count": 56,
|
| 31 |
-
"outputs": [],
|
| 32 |
-
"source": [
|
| 33 |
-
"import pandas as pd\n",
|
| 34 |
-
"import numpy as np\n",
|
| 35 |
-
"import random\n",
|
| 36 |
-
"from glob import glob, escape\n",
|
| 37 |
-
"import imageio.v2 as imageio"
|
| 38 |
-
],
|
| 39 |
-
"metadata": {
|
| 40 |
-
"collapsed": false,
|
| 41 |
-
"pycharm": {
|
| 42 |
-
"name": "#%%\n"
|
| 43 |
-
}
|
| 44 |
-
}
|
| 45 |
-
},
|
| 46 |
-
{
|
| 47 |
-
"cell_type": "code",
|
| 48 |
-
"execution_count": null,
|
| 49 |
-
"outputs": [],
|
| 50 |
-
"source": [
|
| 51 |
-
"# download.download(\"cityscapes\", \"datasets/downloaded\")"
|
| 52 |
-
],
|
| 53 |
-
"metadata": {
|
| 54 |
-
"collapsed": false,
|
| 55 |
-
"pycharm": {
|
| 56 |
-
"name": "#%%\n",
|
| 57 |
-
"is_executing": true
|
| 58 |
-
}
|
| 59 |
-
}
|
| 60 |
-
},
|
| 61 |
-
{
|
| 62 |
-
"cell_type": "code",
|
| 63 |
-
"execution_count": 41,
|
| 64 |
-
"outputs": [],
|
| 65 |
-
"source": [
|
| 66 |
-
"def load_dataset(name=\"gtFine\", path=\"datasets/downloads/\"):\n",
|
| 67 |
-
" src = path+name\n",
|
| 68 |
-
" test, train, val = [f\"{src}/{subpath}\" for subpath in [\"test\", \"train\", \"val\"]]\n",
|
| 69 |
-
"\n",
|
| 70 |
-
" dataset = {\"test\": glob(test + \"/*/*\"), \"train\": glob(train + \"/*/*\"), \"val\": glob(val + \"/*/*\")}\n",
|
| 71 |
-
"\n",
|
| 72 |
-
" return dataset"
|
| 73 |
-
],
|
| 74 |
-
"metadata": {
|
| 75 |
-
"collapsed": false,
|
| 76 |
-
"pycharm": {
|
| 77 |
-
"name": "#%%\n"
|
| 78 |
-
}
|
| 79 |
-
}
|
| 80 |
-
},
|
| 81 |
-
{
|
| 82 |
-
"cell_type": "code",
|
| 83 |
-
"execution_count": 44,
|
| 84 |
-
"outputs": [
|
| 85 |
-
{
|
| 86 |
-
"data": {
|
| 87 |
-
"text/plain": "list"
|
| 88 |
-
},
|
| 89 |
-
"execution_count": 44,
|
| 90 |
-
"metadata": {},
|
| 91 |
-
"output_type": "execute_result"
|
| 92 |
-
}
|
| 93 |
-
],
|
| 94 |
-
"source": [
|
| 95 |
-
"type(load_dataset()[\"train\"])"
|
| 96 |
-
],
|
| 97 |
-
"metadata": {
|
| 98 |
-
"collapsed": false,
|
| 99 |
-
"pycharm": {
|
| 100 |
-
"name": "#%%\n"
|
| 101 |
-
}
|
| 102 |
-
}
|
| 103 |
-
},
|
| 104 |
-
{
|
| 105 |
-
"cell_type": "code",
|
| 106 |
-
"execution_count": 45,
|
| 107 |
-
"outputs": [],
|
| 108 |
-
"source": [
|
| 109 |
-
"a = [1, 2, 3]"
|
| 110 |
-
],
|
| 111 |
-
"metadata": {
|
| 112 |
-
"collapsed": false,
|
| 113 |
-
"pycharm": {
|
| 114 |
-
"name": "#%%\n"
|
| 115 |
-
}
|
| 116 |
-
}
|
| 117 |
-
},
|
| 118 |
-
{
|
| 119 |
-
"cell_type": "code",
|
| 120 |
-
"execution_count": 143,
|
| 121 |
-
"outputs": [],
|
| 122 |
-
"source": [
|
| 123 |
-
"class DataLoader:\n",
|
| 124 |
-
" def __init__(self, data):\n",
|
| 125 |
-
" self.data = np.array(data)\n",
|
| 126 |
-
" self.total = len(self.data)\n",
|
| 127 |
-
" self.__items = self.data\n",
|
| 128 |
-
" self.__remaining = len(self.data)\n",
|
| 129 |
-
" def __next__(self, n=1):\n",
|
| 130 |
-
" if n > self.total:\n",
|
| 131 |
-
" raise ValueError(f\"Dataset doesn't have enough elements to suffice request of {n} elements.\")\n",
|
| 132 |
-
" if self.__remaining > 0:\n",
|
| 133 |
-
" indices = random.sample(range(self.__remaining), n)\n",
|
| 134 |
-
" sampled = self.__items[indices]\n",
|
| 135 |
-
" self.__items = np.delete(self.__items, indices)\n",
|
| 136 |
-
" self.__remaining -= n\n",
|
| 137 |
-
" return sampled\n",
|
| 138 |
-
" else:\n",
|
| 139 |
-
" self.__items = self.data\n",
|
| 140 |
-
" self.__remaining = len(self.data)\n",
|
| 141 |
-
" return self.__next__(n)"
|
| 142 |
-
],
|
| 143 |
-
"metadata": {
|
| 144 |
-
"collapsed": false,
|
| 145 |
-
"pycharm": {
|
| 146 |
-
"name": "#%%\n"
|
| 147 |
-
}
|
| 148 |
-
}
|
| 149 |
-
},
|
| 150 |
-
{
|
| 151 |
-
"cell_type": "code",
|
| 152 |
-
"execution_count": 144,
|
| 153 |
-
"outputs": [],
|
| 154 |
-
"source": [
|
| 155 |
-
"loader = DataLoader(a)"
|
| 156 |
-
],
|
| 157 |
-
"metadata": {
|
| 158 |
-
"collapsed": false,
|
| 159 |
-
"pycharm": {
|
| 160 |
-
"name": "#%%\n"
|
| 161 |
-
}
|
| 162 |
-
}
|
| 163 |
-
},
|
| 164 |
-
{
|
| 165 |
-
"cell_type": "code",
|
| 166 |
-
"execution_count": null,
|
| 167 |
-
"outputs": [],
|
| 168 |
-
"source": [],
|
| 169 |
-
"metadata": {
|
| 170 |
-
"collapsed": false,
|
| 171 |
-
"pycharm": {
|
| 172 |
-
"name": "#%%\n"
|
| 173 |
-
}
|
| 174 |
-
}
|
| 175 |
-
}
|
| 176 |
-
],
|
| 177 |
-
"metadata": {
|
| 178 |
-
"kernelspec": {
|
| 179 |
-
"display_name": "Python 3",
|
| 180 |
-
"language": "python",
|
| 181 |
-
"name": "python3"
|
| 182 |
-
},
|
| 183 |
-
"language_info": {
|
| 184 |
-
"codemirror_mode": {
|
| 185 |
-
"name": "ipython",
|
| 186 |
-
"version": 2
|
| 187 |
-
},
|
| 188 |
-
"file_extension": ".py",
|
| 189 |
-
"mimetype": "text/x-python",
|
| 190 |
-
"name": "python",
|
| 191 |
-
"nbconvert_exporter": "python",
|
| 192 |
-
"pygments_lexer": "ipython2",
|
| 193 |
-
"version": "2.7.6"
|
| 194 |
-
}
|
| 195 |
-
},
|
| 196 |
-
"nbformat": 4,
|
| 197 |
-
"nbformat_minor": 0
|
| 198 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
notebooks/dataset.ipynb
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"outputs": [],
|
| 7 |
+
"source": [
|
| 8 |
+
"import os\n",
|
| 9 |
+
"import gzip"
|
| 10 |
+
],
|
| 11 |
+
"metadata": {
|
| 12 |
+
"collapsed": false,
|
| 13 |
+
"pycharm": {
|
| 14 |
+
"name": "#%%\n"
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": 3,
|
| 21 |
+
"outputs": [
|
| 22 |
+
{
|
| 23 |
+
"name": "stderr",
|
| 24 |
+
"output_type": "stream",
|
| 25 |
+
"text": [
|
| 26 |
+
"/home/rzimmerdev/conda/envs/data/lib/python3.9/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n",
|
| 27 |
+
" warnings.warn(\"Setuptools is replacing distutils.\")\n"
|
| 28 |
+
]
|
| 29 |
+
}
|
| 30 |
+
],
|
| 31 |
+
"source": [
|
| 32 |
+
"from src.downloader import download_dataset\n",
|
| 33 |
+
"download_dataset(\"mnist\", \"../datasets/mnist\")"
|
| 34 |
+
],
|
| 35 |
+
"metadata": {
|
| 36 |
+
"collapsed": false,
|
| 37 |
+
"pycharm": {
|
| 38 |
+
"name": "#%%\n"
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": 4,
|
| 45 |
+
"outputs": [
|
| 46 |
+
{
|
| 47 |
+
"data": {
|
| 48 |
+
"text/plain": "b'\\x00\\x00\\x08\\x03\\x00\\x00\\xea`\\x00\\x00\\x00\\x1c\\x00\\x00\\x00\\x1c'"
|
| 49 |
+
},
|
| 50 |
+
"execution_count": 4,
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"output_type": "execute_result"
|
| 53 |
+
}
|
| 54 |
+
],
|
| 55 |
+
"source": [
|
| 56 |
+
"f = gzip.open(\"../datasets/mnist/\" + os.listdir(\"../datasets/mnist/\")[0], 'r')\n",
|
| 57 |
+
"f.read(16)"
|
| 58 |
+
],
|
| 59 |
+
"metadata": {
|
| 60 |
+
"collapsed": false,
|
| 61 |
+
"pycharm": {
|
| 62 |
+
"name": "#%%\n"
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"cell_type": "code",
|
| 68 |
+
"execution_count": 5,
|
| 69 |
+
"outputs": [],
|
| 70 |
+
"source": [
|
| 71 |
+
"import numpy as np\n",
|
| 72 |
+
"from torch.utils.data import DataLoader, Dataset"
|
| 73 |
+
],
|
| 74 |
+
"metadata": {
|
| 75 |
+
"collapsed": false,
|
| 76 |
+
"pycharm": {
|
| 77 |
+
"name": "#%%\n"
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": 6,
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"class DatasetMNIST(Dataset):\n",
|
| 87 |
+
" def __init__(self, images, labels):\n",
|
| 88 |
+
" with gzip.open(images, 'r') as f:\n",
|
| 89 |
+
" f.read(4)\n",
|
| 90 |
+
" self.total = int.from_bytes(f.read(4), 'big')\n",
|
| 91 |
+
" rows = int.from_bytes(f.read(4), 'big')\n",
|
| 92 |
+
" columns = int.from_bytes(f.read(4), 'big')\n",
|
| 93 |
+
"\n",
|
| 94 |
+
" image_data = f.read()\n",
|
| 95 |
+
" images = np.frombuffer(image_data, dtype=np.uint8)\\\n",
|
| 96 |
+
" .reshape((self.total, rows, columns))\n",
|
| 97 |
+
" self.images = images\n",
|
| 98 |
+
" with gzip.open(labels, 'r') as f:\n",
|
| 99 |
+
" f.read(4)\n",
|
| 100 |
+
" total = int.from_bytes(f.read(4), 'big')\n",
|
| 101 |
+
"\n",
|
| 102 |
+
" label_data = f.read()\n",
|
| 103 |
+
" labels = np.frombuffer(label_data, dtype=np.uint8)\n",
|
| 104 |
+
" self.labels = labels\n",
|
| 105 |
+
" self.data = list(zip(self.images, self.labels))\n",
|
| 106 |
+
" def __getitem__(self, n):\n",
|
| 107 |
+
" if n > self.total:\n",
|
| 108 |
+
" raise ValueError(f\"Dataset doesn't have enough elements to suffice request of {n} elements.\")\n",
|
| 109 |
+
" return self.data[n]\n",
|
| 110 |
+
"\n",
|
| 111 |
+
" def __len__(self):\n",
|
| 112 |
+
" return len(self.data)"
|
| 113 |
+
],
|
| 114 |
+
"metadata": {
|
| 115 |
+
"collapsed": false,
|
| 116 |
+
"pycharm": {
|
| 117 |
+
"name": "#%%\n"
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "code",
|
| 123 |
+
"execution_count": 7,
|
| 124 |
+
"outputs": [],
|
| 125 |
+
"source": [
|
| 126 |
+
"dataset_dir = \"../datasets/mnist/\"\n",
|
| 127 |
+
"loader = DatasetMNIST(dataset_dir + \"train_images\", dataset_dir + \"train_labels\")"
|
| 128 |
+
],
|
| 129 |
+
"metadata": {
|
| 130 |
+
"collapsed": false,
|
| 131 |
+
"pycharm": {
|
| 132 |
+
"name": "#%%\n"
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": 8,
|
| 139 |
+
"outputs": [
|
| 140 |
+
{
|
| 141 |
+
"data": {
|
| 142 |
+
"text/plain": "<Figure size 432x288 with 1 Axes>",
|
| 143 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASY0lEQVR4nO3de9BU9X3H8ffHu4IGKZcgigZk2qi1JjKaTqjQ8VpnVNR6a42AmWBMYpNMbLXUqhm1STrV1GljJogWFJRg1HiJrTI2iqQpkaSoCBjUYkQIqEhUojHAt3+c83ROHnfPLnuX3+c1s/Ps7nfPOV+W/ew5e86e/SkiMLMd307dbsDMOsNhN0uEw26WCIfdLBEOu1kiHHazRDjstl0kXS1pTpXaRElr6pzPFEmLGuyh4WlT5rC3kKTHJL0hafcOLe8gSSFpl+2YZrWk49rZV6+RdIqkZZLelvRfkg7pdk/d4LC3iKSDgD8BAji1u91YH0ljgbnAZ4FBwAPA/dvzBrmjcNhb5wLgv4FZwORiQdIsSd+S9ANJb0laLGlMoR6SPitpVb5l8C1Jyms7SbpC0kuSNki6TdKH8kkX5n835WutP5Y0RtJ/Snpd0muS5koalM/rdmAU8ED++L/J7/9EvsbbJOkpSRMLvX1E0uN53wuAIfU+IZIul/RCPu1ySae//yH6F0m/krRS0rGFwock3SJpnaRXJF0raed6l11wIvBERCyKiC3AN4CRwIQG5vXBFhG+tOACPA98DjgS+C0wvFCbBWwEjgJ2IVvTzCvUA3iQbM0zCngVOCmvXZjPezQwELgHuD2vHZRPu0thXgcDxwO7A0PJ3hD+uVBfDRxXuD0SeB04mezN//j89tC8/mPghnx+xwBvAXOqPAcTgTWF22cB++XzPQfYDIzIa1OALcCXgV3z+q+AwXn9+8B3gAHAMOAnwEWFaRcVlvMgcHmVni4BHirc3hl4F/hit18zHX+NdruBHeECjM8DPiS/vRL4cqE+C5hZuH0ysLJwO4Dxhdvz+168wKPA5wq138+XtUulsFfobRLwP4Xb/cN+Wd+bR+G+h8m2TkblgRxQqN1Rb9gr1JcCp+XXpwBrARXqPwE+BQwHfgPsWaidB/ywMO2iasvpt8w/yN9kJgK7AX8PbAP+ttuvm05fvBnfGpOBRyLitfz2HfTblAd+Wbj+a7K1dD31/YCXCrWXyII+vFIjkoZJmpdv+r4JzKF80/tA4Kx8E36TpE1kb14j8mW/ERGb+y2/LpIukLS0MN/D+vXySuSJLMx7v7ynXYF1hWm/Q7aG3y4RsZLs/+JfgXX58pcDdR012JEkt5Oi1STtCZwN7CypL7C7A4Mk/VFEPNXkItaSvfj79K1t15Ntgvf3NbK1/eER8bqkSWQv9D79T3N8mWzN/pn+M5J0ILCvpAGFwI+qMI/3yae9GTgW+HFEbJW0FFDhYSMlqRD4UcD9eU+/IdtS2lJrWbVExPeA7+V9DSL7aPRks/P9oPGavXmTgK3AIcAR+eWjwBNkO+2adSfw5XxH2UDgH4Dv5iF4lWyTdHTh8XsDb5PttBsJ/HW/+a3v9/g5wCmSTpS0s6Q98uPl+0fES8AS4KuSdpM0Hjilzr4HkL0pvAogaSrZmr1oGPBXknaVdBbZ8/ZQRKwDHgGul7RPvpNyjKSGdqpJOjL/tw0l20J4IF/jJ8Vhb95k4N8i4hcR8cu+C9na9C9bcIjnVuB2sh1t/0u2c+kSgIj4NXAd8KN8c/cTwFeBj5Pt7PoB2Q69oq8BV+SPvzQiXgZOA6aTBfNlsjeIvtfGXwBHk+1gvAq4rZ6mI2I5cD3ZDr71wB8CP+r3sMXAWOC1/N/x5xHxel67gOwz9nLgDbI184hKy5L075Kml7RzI7AJeC7/+76tmBTodz8ymdmOymt2s0Q47GaJcNjNEuGwmyWio8fZJXlvoFmbRYQq3d/Uml3SSZKek/S8pMubmZeZtVfDh9
7yM5B+TnbixBqybySdlx9frTaN1+xmbdaONftRwPMR8WJEvAfMI/tyhpn1oGbCPpLs21Z91lDhu9qSpklaImlJE8sysyY1s4Ou0qbC+zbTI2IGMAO8GW/WTc2s2dcABxRu7092hpaZ9aBmwv4kMDY/G2s34Fyy0xPNrAc1vBkfEVskfYHsV012Bm6NiGdb1pmZtVRHz3rzZ3az9mvLl2rM7IPDYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q0PD47gKTVwFvAVmBLRIxrRVNm1npNhT33pxHxWgvmY2Zt5M14s0Q0G/YAHpH0U0nTKj1A0jRJSyQtaXJZZtYERUTjE0v7RcRaScOABcAlEbGw5PGNL8zM6hIRqnR/U2v2iFib/90A3Asc1cz8zKx9Gg67pAGS9u67DpwALGtVY2bWWs3sjR8O3Cupbz53RMR/tKQrM2u5pj6zb/fC/JndrO3a8pndzD44HHazRDjsZolw2M0S4bCbJaIVJ8JYDzv66KNL6+eff35pfcKECaX1Qw89dLt76nPppZeW1teuXVtaHz9+fGl9zpw5VWuLFy8unXZH5DW7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIn/W2AzjnnHOq1m688cbSaYcMGVJaz09hruqxxx4rrQ8dOrRq7ZBDDimdtpZavd11111Va+eee25Ty+5lPuvNLHEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEz2fvAbvsUv7fMG5c+eC4N998c9XaXnvtVTrtwoVVB/AB4JprrimtL1q0qLS+++67V63Nnz+/dNoTTjihtF7LkiUecazIa3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE+zt4Dav12+8yZMxue94IFC0rrZefCA7z55psNL7vW/Js9jr5mzZrS+uzZs5ua/46m5ppd0q2SNkhaVrhvsKQFklblf/dtb5tm1qx6NuNnASf1u+9y4NGIGAs8mt82sx5WM+wRsRDY2O/u04C+baTZwKTWtmVmrdboZ/bhEbEOICLWSRpW7YGSpgHTGlyOmbVI23fQRcQMYAb4ByfNuqnRQ2/rJY0AyP9uaF1LZtYOjYb9fmByfn0ycF9r2jGzdqn5u/GS7gQmAkOA9cBVwPeB+cAo4BfAWRHRfydepXkluRlf65zw6dOnl9Zr/R/ddNNNVWtXXHFF6bTNHkevZcWKFVVrY8eObWreZ555Zmn9vvvSXAdV+934mp/ZI+K8KqVjm+rIzDrKX5c1S4TDbpYIh90sEQ67WSIcdrNE+BTXFrjyyitL67UOrb333nul9Ycffri0ftlll1WtvfPOO6XT1rLHHnuU1mudpjpq1KiqtVpDLl977bWl9VQPrTXKa3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE1T3Ft6cI+wKe4Dho0qGpt5cqVpdMOGTKktP7ggw+W1idNmlRab8bBBx9cWp87d25p/cgjj2x42XfffXdp/cILLyytb968ueFl78iqneLqNbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulggfZ6/TsGFVR7hi7dq1Tc179OjRpfV33323tD516tSqtVNPPbV02sMOO6y0PnDgwNJ6rddPWf2MM84onfaBBx4orVtlPs5uljiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCx9nrVHY+e9mwxABDhw4trdf6/fR2/h/V+o5Ard5GjBhRWn/11VcbntYa0/Bxdkm3StogaVnhvqslvSJpaX45uZXNmlnr1bMZPws4qcL934yII/LLQ61ty8xarWbYI2IhsLEDvZhZGzWzg+4Lkp7ON/P3rfYgSdMkLZG0pIllmVmTGg37t4ExwBHAOu
D6ag+MiBkRMS4ixjW4LDNrgYbCHhHrI2JrRGwDbgaOam1bZtZqDYVdUvGYyenAsmqPNbPeUHN8dkl3AhOBIZLWAFcBEyUdAQSwGriofS32hk2bNlWt1fpd91q/Cz948ODS+gsvvFBaLxunfNasWaXTbtxYvu913rx5pfVax8prTW+dUzPsEXFehbtvaUMvZtZG/rqsWSIcdrNEOOxmiXDYzRLhsJsloubeeKtt8eLFpfVap7h20zHHHFNanzBhQml927ZtpfUXX3xxu3uy9vCa3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhI+zJ27PPfcsrdc6jl7rZ659imvv8JrdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEh2y2Ulu3bi2t13r9lP3UdNlwzta4hodsNrMdg8NulgiH3SwRDrtZIhx2s0Q47GaJcNjNElHPkM0HALcBHwa2ATMi4kZJg4HvAgeRDdt8dkS80b5WrR1OPPHEbrdgHVLPmn0L8JWI+CjwCeDzkg4BLgcejYixwKP5bTPrUTXDHhHrIuJn+fW3gBXASOA0YHb+sNnApDb1aGYtsF2f2SUdBHwMWAwMj4h1kL0hAMNa3p2ZtUzdv0EnaSBwN/CliHhTqvj120rTTQOmNdaembVKXWt2SbuSBX1uRNyT371e0oi8PgLYUGnaiJgREeMiYlwrGjazxtQMu7JV+C3Aioi4oVC6H5icX58M3Nf69sysVerZjP8k8CngGUlL8/umA18H5kv6NPAL4Ky2dGhtNXr06G63YB1SM+wRsQio9gH92Na2Y2bt4m/QmSXCYTdLhMNulgiH3SwRDrtZIhx2s0R4yObEPfHEE6X1nXYqXx/UGtLZeofX7GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZInycPXHLli0rra9ataq0Xut8+DFjxlStecjmzvKa3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhCKicwuTOrcwa4kpU6aU1mfOnFlaf/zxx6vWLrnkktJply9fXlq3yiKi4k+/e81ulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyWi5nF2SQcAtwEfBrYBMyLiRklXA58B+k5Knh4RD9WYl4+zf8Dss88+pfX58+eX1o877riqtXvuuad02qlTp5bWN2/eXFpPVbXj7PX8eMUW4CsR8TNJewM/lbQgr30zIv6pVU2aWfvUDHtErAPW5dffkrQCGNnuxsystbbrM7ukg4CPAYvzu74g6WlJt0rat8o00yQtkbSkuVbNrBl1h13SQOBu4EsR8SbwbWAMcATZmv/6StNFxIyIGBcR45pv18waVVfYJe1KFvS5EXEPQESsj4itEbENuBk4qn1tmlmzaoZdkoBbgBURcUPh/hGFh50OlP9MqZl1VT2H3sYDTwDPkB16A5gOnEe2CR/AauCifGde2bx86G0HU+vQ3HXXXVe1dvHFF5dOe/jhh5fWfQpsZQ0feouIRUCliUuPqZtZb/E36MwS4bCbJcJhN0uEw26WCIfdLBEOu1ki/FPSZjsY/5S0WeIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpaIen5dtpVeA14q3B6S39eLerW3Xu0L3FujWtnbgdUKHf1SzfsWLi3p1d+m69XeerUvcG+N6lRv3ow3S4TDbpaIbod9RpeXX6ZXe+vVvsC9NaojvXX1M7uZdU631+xm1iEOu1kiuhJ2SSdJek7S85Iu70YP1UhaLekZSUu7PT5dPobeBknLCvcNlrRA0qr8b8Ux9rrU29WSXsmfu6WSTu5SbwdI+qGkFZKelfTF/P6uPnclfXXkeev4Z3ZJOwM/B44H1gBPAudFRE/84r+k1cC4iOj6FzAkHQO8DdwWEYfl9/0jsDEivp6/Ue4bEZf1SG9XA293exjvfLSiEcVhxoFJwBS6+NyV9HU2HXjeurFmPwp4PiJejIj3gHnAaV3oo+dFxEJgY7+7TwNm59dnk7
1YOq5Kbz0hItZFxM/y628BfcOMd/W5K+mrI7oR9pHAy4Xba+it8d4DeETSTyVN63YzFQzvG2Yr/zusy/30V3MY707qN8x4zzx3jQx/3qxuhL3S72P10vG/T0bEx4E/Az6fb65afeoaxrtTKgwz3hMaHf68Wd0I+xrggMLt/YG1XeijoohYm//dANxL7w1Fvb5vBN3874Yu9/P/emkY70rDjNMDz103hz/vRtifBMZK+oik3YBzgfu70Mf7SBqQ7zhB0gDgBHpvKOr7gcn59cnAfV3s5Xf0yjDe1YYZp8vPXdeHP4+Ijl+Ak8n2yL8A/F03eqjS12jgqfzybLd7A+4k26z7LdkW0aeB3wMeBVblfwf3UG+3kw3t/TRZsEZ0qbfxZB8NnwaW5peTu/3clfTVkefNX5c1S4S/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJeL/AHyD7vpJDzRWAAAAAElFTkSuQmCC\n"
|
| 144 |
+
},
|
| 145 |
+
"metadata": {
|
| 146 |
+
"needs_background": "light"
|
| 147 |
+
},
|
| 148 |
+
"output_type": "display_data"
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"source": [
|
| 152 |
+
"import matplotlib.pyplot as plt\n",
|
| 153 |
+
"X, y = loader[4]\n",
|
| 154 |
+
"plt.imshow(X, cmap=\"gray\")\n",
|
| 155 |
+
"plt.title(label=\"Annotated label: \" + str(y))\n",
|
| 156 |
+
"plt.show()"
|
| 157 |
+
],
|
| 158 |
+
"metadata": {
|
| 159 |
+
"collapsed": false,
|
| 160 |
+
"pycharm": {
|
| 161 |
+
"name": "#%%\n"
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
],
|
| 166 |
+
"metadata": {
|
| 167 |
+
"kernelspec": {
|
| 168 |
+
"display_name": "Python 3",
|
| 169 |
+
"language": "python",
|
| 170 |
+
"name": "python3"
|
| 171 |
+
},
|
| 172 |
+
"language_info": {
|
| 173 |
+
"codemirror_mode": {
|
| 174 |
+
"name": "ipython",
|
| 175 |
+
"version": 2
|
| 176 |
+
},
|
| 177 |
+
"file_extension": ".py",
|
| 178 |
+
"mimetype": "text/x-python",
|
| 179 |
+
"name": "python",
|
| 180 |
+
"nbconvert_exporter": "python",
|
| 181 |
+
"pygments_lexer": "ipython2",
|
| 182 |
+
"version": "2.7.6"
|
| 183 |
+
}
|
| 184 |
+
},
|
| 185 |
+
"nbformat": 4,
|
| 186 |
+
"nbformat_minor": 0
|
| 187 |
+
}
|
src/dataset.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding: utf-8
|
| 3 |
+
import gzip
|
| 4 |
+
|
| 5 |
+
from src.downloader import download_dataset
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_mnist(download_dir):
    """Make sure MNIST is downloaded and return the paths of its archives.

    Args:
        download_dir: Directory the archives are stored in. A trailing path
            separator is optional.

    Returns:
        A dict with the keys ``"train"`` and ``"test"``; each value is an
        ``(images_path, labels_path)`` tuple of files inside ``download_dir``.
    """
    # Function-scope import: this module does not import os at file level.
    import os

    download_dataset("mnist", download_dir)

    def archive(name):
        # os.path.join only inserts a separator when download_dir lacks one,
        # so "dir" and "dir/" both resolve to "dir/<name>".
        return os.path.join(download_dir, name)

    return {"train": (archive("train_images"), archive("train_labels")),
            "test": (archive("test_images"), archive("test_labels"))}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DatasetMNIST(Dataset):
    """Map-style dataset over a pair of gzip-compressed MNIST IDX files.

    Every item is an ``(image, label)`` tuple: ``image`` is a 2-D ``uint8``
    array of shape ``(rows, columns)`` and ``label`` is a ``uint8`` scalar.
    Both files are read completely into memory by the constructor.
    """

    # Magic numbers that open an IDX image file and an IDX label file.
    _IMAGE_MAGIC = 2051
    _LABEL_MAGIC = 2049

    def __init__(self, images, labels):
        """Parse the image and label archives.

        Args:
            images: Path to the gzip-compressed IDX image file.
            labels: Path to the gzip-compressed IDX label file.

        Raises:
            ValueError: If either header carries the wrong magic number, if
                the image payload does not match the declared dimensions, or
                if the number of labels differs from the number of images.
        """
        with gzip.open(images, 'rb') as f:
            magic = int.from_bytes(f.read(4), 'big')
            if magic != self._IMAGE_MAGIC:
                raise ValueError(f"{images} is not an IDX image file (magic number {magic}).")
            # Header layout: image count, rows, columns (big-endian uint32).
            self.total = int.from_bytes(f.read(4), 'big')
            rows = int.from_bytes(f.read(4), 'big')
            columns = int.from_bytes(f.read(4), 'big')

            image_data = f.read()
        # reshape raises ValueError when the payload is truncated or padded.
        self.images = np.frombuffer(image_data, dtype=np.uint8).reshape((self.total, rows, columns))

        with gzip.open(labels, 'rb') as f:
            magic = int.from_bytes(f.read(4), 'big')
            if magic != self._LABEL_MAGIC:
                raise ValueError(f"{labels} is not an IDX label file (magic number {magic}).")
            declared_labels = int.from_bytes(f.read(4), 'big')

            label_data = f.read()
        self.labels = np.frombuffer(label_data, dtype=np.uint8)

        # zip() below would silently drop the surplus of the longer sequence,
        # so a mismatched image/label pair has to be rejected explicitly.
        if declared_labels != self.total or len(self.labels) != self.total:
            raise ValueError(
                f"Label file holds {len(self.labels)} labels (header declares {declared_labels}) "
                f"but image file holds {self.total} images.")

        self.data = list(zip(self.images, self.labels))

    def __getitem__(self, n):
        """Return the ``(image, label)`` tuple stored at index ``n``.

        Negative indices count from the end, as with a list.

        Raises:
            ValueError: If ``n`` is greater than the number of samples.
            IndexError: If ``n`` equals the number of samples or lies below
                ``-len(self)``; raised by the underlying list lookup.
        """
        if n > self.total:
            raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
        # n == self.total deliberately falls through to the list's IndexError,
        # which is what ends a plain ``for item in dataset`` loop.
        return self.data[n]

    def __len__(self):
        """Return the number of ``(image, label)`` samples."""
        return len(self.data)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
    # Demo entry point: fetch MNIST if necessary, then display the fifth
    # training sample together with its annotated label.
    download_dir = "../downloads/mnist/"
    train_images, train_labels = load_mnist(download_dir)["train"]

    dataset = DatasetMNIST(train_images, train_labels)

    # matplotlib is only needed for this demo, so it is imported here.
    import matplotlib.pyplot as plt

    X, y = dataset[4]
    caption = "Annotated label: " + str(y)
    plt.imshow(X, cmap="gray")
    plt.title(label=caption)
    plt.show()
|
{datasets → src}/downloader.py
RENAMED
|
@@ -4,26 +4,38 @@
|
|
| 4 |
# To learn more about the dataset, access:
|
| 5 |
# https://www.cityscapes-dataset.com/
|
| 6 |
import os
|
| 7 |
-
import sys
|
| 8 |
import pip
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
pass
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def download(name='cityscapes', path='datasets/downloads'):
|
| 17 |
-
"""Select one of the available and implemented datasets to download:
|
| 18 |
name=any(['cityscapes', 'camvid', 'labelme'])
|
| 19 |
"""
|
| 20 |
if name == 'cityscapes':
|
| 21 |
download_cityscapes(path)
|
|
|
|
|
|
|
| 22 |
else:
|
| 23 |
raise NotImplementedError
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
if hasattr(pip, 'main'):
|
| 28 |
pip.main(['install', 'cityscapesscripts'])
|
| 29 |
else:
|
|
@@ -36,7 +48,3 @@ def download_cityscapes(path='datasets/downloads'):
|
|
| 36 |
print("Invalid dataset name. Please try again.")
|
| 37 |
ds_name = input()
|
| 38 |
os.system(f"csDownload {ds_name} -d {path}/{ds_name}")
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
if __name__ == "__main__":
|
| 42 |
-
main()
|
|
|
|
| 4 |
# To learn more about the dataset, access:
|
| 5 |
# https://www.cityscapes-dataset.com/
|
| 6 |
import os
|
|
|
|
| 7 |
import pip
|
| 8 |
+
from urllib.request import urlretrieve
|
| 9 |
|
| 10 |
|
| 11 |
+
def download_dataset(name='cityscapes', path='downloads/downloads'):
    """Download one of the implemented datasets into ``path``.

    Args:
        name: Dataset identifier; ``'cityscapes'`` and ``'mnist'`` are
            implemented.
        path: Directory the dataset files are written to.

    Raises:
        NotImplementedError: If ``name`` is not an implemented dataset.
    """
    if name == 'cityscapes':
        download_cityscapes(path)
    elif name == "mnist":
        # Delegate to the MNIST downloader, which skips files already present.
        download_mnist(path)
    else:
        raise NotImplementedError(f"No downloader is implemented for dataset {name!r}.")
|
| 21 |
|
| 22 |
|
| 23 |
+
def download_mnist(path="downloads/mnist"):
    """Download the four gzip-compressed MNIST archives into ``path``.

    The files are stored under the fixed names ``train_images``,
    ``train_labels``, ``test_images`` and ``test_labels``. A file that already
    exists is treated as cached and is not downloaded again.

    Args:
        path: Destination directory; created when it does not exist.
    """
    remote_files = {"train_images": "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
                    "train_labels": "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
                    "test_images": "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
                    "test_labels": "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
    # exist_ok avoids the race of a separate exists() check before creation.
    os.makedirs(path, exist_ok=True)

    for file, url in remote_files.items():
        target = os.path.join(path, file)
        if os.path.exists(target):
            continue

        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a partial file that a later run
        # would mistake for a cached archive.
        partial = target + ".part"
        try:
            urlretrieve(url, partial)
            os.replace(partial, target)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def download_cityscapes(path='downloads/cityscapes'):
|
| 39 |
if hasattr(pip, 'main'):
|
| 40 |
pip.main(['install', 'cityscapesscripts'])
|
| 41 |
else:
|
|
|
|
| 48 |
print("Invalid dataset name. Please try again.")
|
| 49 |
ds_name = input()
|
| 50 |
os.system(f"csDownload {ds_name} -d {path}/{ds_name}")
|
|
|
|
|
|
|
|
|
|
|
|