Spaces:
Build error
Build error
import tensorflow as tf | |
import tensorflow_datasets as tfds | |
import jax | |
import flax | |
import numpy as np | |
from PIL import Image | |
import os | |
from typing import Sequence | |
from tqdm import tqdm | |
import json | |
from tqdm import tqdm | |
import logging | |
logger = logging.getLogger(__name__) | |
def prefetch(dataset, n_prefetch):
    """Iterate over `dataset` as NumPy arrays, optionally prefetching to device.

    Adapted from:
    https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: Iterable yielding pytrees (e.g. dicts) whose leaves expose
            the buffer protocol.
        n_prefetch (int): Number of batches handed to
            `flax.jax_utils.prefetch_to_device`; a falsy value (0 or None)
            disables device prefetching.

    Returns:
        Iterator over the converted batches. Without prefetching every leaf
        is an `np.ndarray`; with prefetching the items are whatever
        `flax.jax_utils.prefetch_to_device` yields.
    """
    def to_numpy(batch):
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias was deprecated and is gone in recent JAX.
        # Each leaf is wrapped via the buffer protocol (memoryview).
        return jax.tree_util.tree_map(
            lambda leaf: np.asarray(memoryview(leaf)), batch)

    # `map` keeps the conversion lazy: one batch is converted per `next()`.
    ds_iter = map(to_numpy, iter(dataset))
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter
def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
    """
    Build the input pipeline from a TFRecord file.

    Reads `dataset_info.json` and `dataset.tfrecords` from `data_dir`. Each
    record is parsed, resized to `img_size` x `img_size`, randomly flipped
    left/right, normalized with `(x - 127.5) / 127.5` and paired with a
    one-hot label. Batches are reshaped so that the leading axis indexes the
    local devices. The returned dataset is not repeated.

    Args:
        data_dir (str): Root directory of the dataset.
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_local_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.

    Returns:
        (tuple): `(ds, dataset_info)` where `ds` is a `tf.data.Dataset`
            yielding dicts with
            'image' of shape [num_local_devices, batch_size, img_size, img_size, img_channels]
            and 'label' of shape [num_local_devices, batch_size, num_classes],
            and `dataset_info` is the dict loaded from `dataset_info.json`
            (its 'num_examples' entry is used to cap the shuffle buffer).
    """
    feature_spec = {'height': tf.io.FixedLenFeature([], tf.int64),
                    'width': tf.io.FixedLenFeature([], tf.int64),
                    'channels': tf.io.FixedLenFeature([], tf.int64),
                    'image': tf.io.FixedLenFeature([], tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64)}

    def pre_process(serialized_example):
        # Decode one serialized record into a normalized image and one-hot label.
        example = tf.io.parse_single_example(serialized_example, feature_spec)
        # The size features are already parsed as int64, so they are used directly.
        image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        image = tf.reshape(image, shape=[example['height'], example['width'], example['channels']])
        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)
        image = (image - 127.5) / 127.5
        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape from [num_devices * batch_size, ...] to [num_devices, batch_size, ...]
        # because the first dimension will be mapped across devices using jax.pmap.
        # The batch axis is spelled out instead of -1: batches are built with
        # drop_remainder=True so it is always `batch_size`, and an explicit
        # size also works when num_classes == 0 makes the label tensor empty
        # (a -1 dimension cannot be inferred for an empty tensor).
        return {
            'image': tf.reshape(data['image'],
                                [num_local_devices, batch_size, img_size, img_size, img_channels]),
            'label': tf.reshape(data['label'],
                                [num_local_devices, batch_size, num_classes]),
        }

    logger.info('Loading TFRecord...')
    with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
        dataset_info = json.load(fin)

    ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    # Each JAX process reads its own slice of the records.
    ds = ds.shard(jax.process_count(), jax.process_index())
    ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    ds = ds.batch(batch_size * num_local_devices, drop_remainder=True)  # uses per-worker batch size
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)  # prefetches the next batch
    return ds, dataset_info