Spaces:
Build error
Build error
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Sequence | |
| from tqdm import tqdm | |
| import argparse | |
| import json | |
| import os | |
| import logging | |
logger = logging.getLogger(__name__)

# Lower-case file extensions (including the dot) that are treated as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Name of the TFRecord file written inside ``data_dir``.
_TFRECORD_FILENAME = 'dataset.tfrecords'


def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer value in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _list_image_files(directory):
    """Return the sorted names of the image files directly inside ``directory``.

    A file counts as an image when its extension, compared case-insensitively,
    is one of ``_IMAGE_EXTENSIONS``. Sorting makes the record order
    reproducible, since ``os.listdir`` returns entries in arbitrary order.
    """
    return sorted(
        name for name in os.listdir(directory)
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    )


def _serialize_image(image_path, label):
    """Load one image and return it as a serialized ``tf.train.Example``.

    The example stores the raw uint8 pixel buffer under ``'image'`` together
    with ``'height'``, ``'width'``, ``'channels'`` and ``'label'``.
    """
    # The context manager closes the underlying file handle; np.array() forces
    # the pixel data to be read before that happens.
    with Image.open(image_path) as img:
        pixels = np.array(img, dtype=np.uint8)

    height, width = pixels.shape[0], pixels.shape[1]
    # 2-D arrays (e.g. grayscale images) have no channel axis; record them as
    # single-channel instead of failing on shape[2].
    channels = pixels.shape[2] if pixels.ndim == 3 else 1

    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'channels': _int64_feature(channels),
        'image': _bytes_feature(pixels.tobytes()),
        'label': _int64_feature(label)}))
    return example.SerializeToString()


def images_to_tfrecords(image_dir, data_dir, has_labels):
    """
    Converts a folder of images to a TFRecord file.

    The image directory should have one of the following structures:

    If has_labels = False, image_dir should look like this:

        path/to/image_dir/
            0.jpg
            1.jpg
            2.jpg
            3.jpg
            ...

    If has_labels = True, image_dir should look like this:

        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                c.jpg
                ...
            ...

    Label directories are numbered in sorted name order, so the labels will
    be label0 -> 0, label1 -> 1. Without label directories every example gets
    the dummy label 0 and 'num_classes' is reported as 0.

    Only files ending in .png, .jpg or .jpeg (case-insensitive) are converted;
    other files inside the scanned directories are skipped.

    The records are written to '<data_dir>/dataset.tfrecords' and the dataset
    info is additionally saved to '<data_dir>/dataset_info.json'.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFrecords dataset is stored. Created if
            it does not exist.
        has_labels (bool): If True, 'image_dir' contains label directories.

    Returns:
        (dict): Dataset info with the keys 'num_examples' and 'num_classes',
            or None if has_labels is True and 'image_dir' contains an entry
            that is not a directory. In that case warnings are logged, any
            existing 'dataset.tfrecords' in data_dir is removed and no info
            file is written.
    """
    os.makedirs(data_dir, exist_ok=True)
    tfrecord_path = os.path.join(data_dir, _TFRECORD_FILENAME)

    if has_labels:
        # Sorted so that the label -> index mapping is deterministic.
        label_dirs = sorted(os.listdir(image_dir))
        # Validate the layout before opening the writer, so an invalid
        # directory never leaves a half-written or still-open record file.
        if not all(os.path.isdir(os.path.join(image_dir, name)) for name in label_dirs):
            logger.warning('The image directory should contain one directory for each label.')
            logger.warning('These label directories should contain the image files.')
            # A stale dataset from an earlier run must not survive a failed
            # conversion.
            if os.path.exists(tfrecord_path):
                os.remove(tfrecord_path)
            return None
        sources = [(os.path.join(image_dir, name), index)
                   for index, name in enumerate(label_dirs)]
        num_classes = len(label_dirs)
    else:
        sources = [(image_dir, 0)]  # dummy label
        num_classes = 0

    num_examples = 0
    # The context manager closes the writer even if an image fails to load.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for directory, label in sources:
            for img_file in tqdm(_list_image_files(directory)):
                writer.write(_serialize_image(os.path.join(directory, img_file), label))
                num_examples += 1

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)
    return dataset_info
def parse_args(argv=None):
    """Parse the command-line options of the conversion script.

    Args:
        argv (list[str] | None): Argument strings to parse. If None, argparse
            falls back to the process arguments (sys.argv[1:]).

    Returns:
        (argparse.Namespace): Namespace with 'image_dir', 'data_dir' and
            'has_labels'.
    """
    parser = argparse.ArgumentParser(
        description='Converts a folder of images to a TFRecord file.')
    # Both paths are mandatory: without them the conversion would be called
    # with None and fail with an unhelpful TypeError.
    parser.add_argument('--image_dir', type=str, required=True,
                        help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true',
                        help='If True, image_dir contains label directories.')
    return parser.parse_args(argv)


def main(argv=None):
    """Run the image-to-TFRecord conversion with command-line arguments."""
    args = parse_args(argv)
    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)


if __name__ == '__main__':
    main()