Spaces:
Build error
Build error
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
from typing import Sequence | |
from tqdm import tqdm | |
import argparse | |
import json | |
import os | |
import logging | |
# Logger for this module; its name follows the module name so callers can
# configure its level/handlers through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
# File extensions (lower-case, with dot) that are treated as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _is_image_file(filename):
    """Return True if *filename* has a supported image extension (case-insensitive)."""
    return os.path.splitext(filename)[1].lower() in _IMAGE_EXTENSIONS


def _sorted_image_files(directory):
    """Return the names of the regular image files in *directory*, sorted by name."""
    return sorted(
        name for name in os.listdir(directory)
        if _is_image_file(name) and os.path.isfile(os.path.join(directory, name))
    )


def _image_dims(pixels):
    """Return (height, width, channels) of an image array; 2-D arrays count as 1 channel."""
    height, width = pixels.shape[:2]
    channels = pixels.shape[2] if pixels.ndim == 3 else 1
    return height, width, channels


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer value in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _image_to_example(path, label):
    """Decode the image at *path* and pack its raw uint8 pixels into a tf.train.Example."""
    # The context manager closes the file handle; np.array() forces PIL to load
    # the pixel data before that happens.
    # NOTE(review): palette-mode ('P') images yield 2-D arrays of palette
    # indices rather than colours; convert upstream if that matters.
    with Image.open(path) as img:
        pixels = np.array(img, dtype=np.uint8)
    height, width, channels = _image_dims(pixels)
    return tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'channels': _int64_feature(channels),
        'image': _bytes_feature(pixels.tobytes()),
        'label': _int64_feature(label),
    }))


def images_to_tfrecords(image_dir, data_dir, has_labels):
    """
    Converts a folder of images to a TFRecord file.

    Every .png/.jpg/.jpeg file (extension matched case-insensitively) is
    decoded with PIL and written to <data_dir>/dataset.tfrecords as one
    tf.train.Example with the features 'height', 'width', 'channels',
    'image' (raw uint8 pixel bytes) and 'label'. Single-band images (2-D
    arrays, e.g. PIL mode 'L') are stored with channels = 1. Files are
    processed in sorted name order. A summary is written to
    <data_dir>/dataset_info.json.

    The image directory should have one of the following structures:

    If has_labels = False, image_dir should look like this:
        path/to/image_dir/
            0.jpg
            1.jpg
            2.jpg
            3.jpg
            ...
    Every image then gets the dummy label 0 and num_classes is reported as 0.

    If has_labels = True, image_dir should look like this:
        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                c.jpg
                ...
            ...
    Label directories are sorted lexicographically by name and numbered in
    that order, so label0 -> 0, label1 -> 1. If image_dir contains any entry
    that is not a directory, a warning is logged, any existing
    dataset.tfrecords in data_dir is removed, and nothing is written.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFrecords dataset is stored; created if missing.
        has_labels (bool): If True, 'image_dir' contains label directories.

    Returns:
        (dict | None): Dataset info {'num_examples': int, 'num_classes': int},
        or None when has_labels is True and the directory layout is invalid.
    """
    os.makedirs(data_dir, exist_ok=True)
    record_path = os.path.join(data_dir, 'dataset.tfrecords')

    if has_labels:
        # os.listdir order is arbitrary; sorting makes label ids reproducible.
        label_dirs = sorted(os.listdir(image_dir))
        if not all(os.path.isdir(os.path.join(image_dir, name)) for name in label_dirs):
            logger.warning('The image directory should contain one directory for each label.')
            logger.warning('These label directories should contain the image files.')
            # Validate before opening the writer so no handle leaks; still
            # remove any previous output so a stale dataset is not left behind.
            if os.path.exists(record_path):
                os.remove(record_path)
            return None
        # (directory to scan, label assigned to its images)
        sources = [(os.path.join(image_dir, name), label)
                   for label, name in enumerate(label_dirs)]
        num_classes = len(label_dirs)
    else:
        sources = [(image_dir, 0)]  # dummy label
        num_classes = 0

    num_examples = 0
    # The with-block guarantees the writer is flushed and closed even if an
    # image fails to decode.
    with tf.io.TFRecordWriter(record_path) as writer:
        for directory, label in sources:
            for img_file in tqdm(_sorted_image_files(directory)):
                example = _image_to_example(os.path.join(directory, img_file), label)
                writer.write(example.SerializeToString())
                num_examples += 1

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)
    return dataset_info
if __name__ == '__main__':
    # Command-line entry point: convert an image folder into
    # <data_dir>/dataset.tfrecords plus <data_dir>/dataset_info.json.
    parser = argparse.ArgumentParser(
        description='Convert a folder of images to a TFRecords dataset.')
    # Both paths are mandatory; without required=True a missing flag would
    # surface later as an opaque TypeError from os.path.join(None, ...).
    parser.add_argument('--image_dir', type=str, required=True,
                        help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true',
                        help='If True, image_dir contains label directories.')
    args = parser.parse_args()
    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)