Spaces:
Sleeping
Sleeping
| """ A dataset parser that reads images from folders | |
| Folders are scannerd recursively to find image files. Labels are based | |
| on the folder hierarchy, just leaf folders by default. | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
| import os | |
| from timm.utils.misc import natural_key | |
| from .parser import Parser | |
| from .class_map import load_class_map | |
| from .constants import IMG_EXTENSIONS | |
| def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): | |
| labels = [] | |
| filenames = [] | |
| for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): | |
| rel_path = os.path.relpath(root, folder) if (root != folder) else '' | |
| label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') | |
| for f in files: | |
| base, ext = os.path.splitext(f) | |
| if ext.lower() in types: | |
| filenames.append(os.path.join(root, f)) | |
| labels.append(label) | |
| if class_to_idx is None: | |
| # building class index | |
| unique_labels = set(labels) | |
| sorted_labels = list(sorted(unique_labels, key=natural_key)) | |
| class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} | |
| images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] | |
| if sort: | |
| images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) | |
| return images_and_targets, class_to_idx | |
| class ParserImageFolder(Parser): | |
| def __init__( | |
| self, | |
| root, | |
| class_map=''): | |
| super().__init__() | |
| self.root = root | |
| class_to_idx = None | |
| if class_map: | |
| class_to_idx = load_class_map(class_map, root) | |
| self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) | |
| if len(self.samples) == 0: | |
| raise RuntimeError( | |
| f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') | |
| def __getitem__(self, index): | |
| path, target = self.samples[index] | |
| return open(path, 'rb'), target | |
| def __len__(self): | |
| return len(self.samples) | |
| def _filename(self, index, basename=False, absolute=False): | |
| filename = self.samples[index][0] | |
| if basename: | |
| filename = os.path.basename(filename) | |
| elif not absolute: | |
| filename = os.path.relpath(filename, self.root) | |
| return filename | |