Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import torchvision | |
| from tqdm.auto import tqdm | |
| CLASSES = ( | |
| "plane", | |
| "car", | |
| "bird", | |
| "cat", | |
| "deer", | |
| "dog", | |
| "frog", | |
| "horse", | |
| "ship", | |
| "truck", | |
| ) | |
| def main(): | |
| for split in ["train", "test"]: | |
| out_dir = f"cifar_{split}" | |
| if os.path.exists(out_dir): | |
| print(f"skipping split {split} since {out_dir} already exists.") | |
| continue | |
| print("downloading...") | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
| dataset = torchvision.datasets.CIFAR10( | |
| root=tmp_dir, train=split == "train", download=True | |
| ) | |
| print("dumping images...") | |
| os.mkdir(out_dir) | |
| for i in tqdm(range(len(dataset))): | |
| image, label = dataset[i] | |
| filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png") | |
| image.save(filename) | |
| if __name__ == "__main__": | |
| main() | |