# Source: stylegan2-flax-tpu / data_pipeline.py
# Uploaded by akhaliq (HF staff), commit 81170fd ("add files").
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax
import numpy as np
from PIL import Image
import os
from typing import Sequence
from tqdm import tqdm
import json
from tqdm import tqdm
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def prefetch(dataset, n_prefetch):
    """
    Wraps a dataset in an iterator that yields pytrees of NumPy arrays.

    Taken from:
    https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset (iterable): Iterable of batches. Each batch is a pytree
            (e.g. a dict) whose leaves support the buffer protocol.
        n_prefetch (int): Number of batches to prefetch to the devices.
            A falsy value (0 or None) disables device prefetching.

    Returns:
        (iterator): Lazy iterator over the converted batches.
    """
    def _to_numpy(batch):
        # Going through memoryview exposes each leaf's buffer to NumPy
        # directly instead of requiring a framework-specific conversion.
        # `jax.tree_util.tree_map` is used because the top-level
        # `jax.tree_map` alias is deprecated and absent in recent JAX.
        return jax.tree_util.tree_map(lambda leaf: np.asarray(memoryview(leaf)), batch)

    ds_iter = map(_to_numpy, iter(dataset))
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter
def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
    """
    Builds the input pipeline from a TFRecord file.

    Reads 'dataset_info.json' and 'dataset.tfrecords' from `data_dir`. The
    records are sharded across JAX processes, shuffled, decoded, resized,
    randomly flipped, rescaled via (x - 127.5) / 127.5, batched and finally
    reshaped so the leading axis indexes the local devices.

    Args:
        data_dir (str): Root directory of the dataset.
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_local_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.

    Returns:
        (tuple): The tf.data.Dataset yielding dicts with the keys 'image' and
            'label', and the dict loaded from 'dataset_info.json'.
    """
    # Layout of a single serialized example inside the TFRecord file.
    feature_spec = {'height': tf.io.FixedLenFeature([], tf.int64),
                    'width': tf.io.FixedLenFeature([], tf.int64),
                    'channels': tf.io.FixedLenFeature([], tf.int64),
                    'image': tf.io.FixedLenFeature([], tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64)}

    def _decode_example(record):
        # Parse one record and turn it into a training-ready image/label pair.
        parsed = tf.io.parse_single_example(record, feature_spec)
        stored_shape = [tf.cast(parsed[key], dtype=tf.int64)
                        for key in ('height', 'width', 'channels')]

        # The image is stored as raw uint8 bytes and has to be reshaped
        # using the dimensions saved alongside it.
        pixels = tf.io.decode_raw(parsed['image'], out_type=tf.uint8)
        pixels = tf.reshape(pixels, shape=stored_shape)
        pixels = tf.cast(pixels, dtype='float32')
        pixels = tf.image.resize(pixels, size=[img_size, img_size], method='bicubic', antialias=True)
        pixels = tf.image.random_flip_left_right(pixels)
        pixels = (pixels - 127.5) / 127.5

        one_hot_label = tf.one_hot(parsed['label'], num_classes)
        return {'image': pixels, 'label': one_hot_label}

    def _split_across_devices(batch):
        # Reshape from [num_devices * batch_size, ...] to [num_devices, batch_size, ...]
        # because the first dimension will be mapped across devices using jax.pmap
        image_shape = [num_local_devices, -1, img_size, img_size, img_channels]
        label_shape = [num_local_devices, -1, num_classes]
        return {'image': tf.reshape(batch['image'], image_shape),
                'label': tf.reshape(batch['label'], label_shape)}

    logger.info('Loading TFRecord...')

    info_path = os.path.join(data_dir, 'dataset_info.json')
    with tf.io.gfile.GFile(info_path, 'r') as info_file:
        dataset_info = json.load(info_file)

    record_path = os.path.join(data_dir, 'dataset.tfrecords')
    buffer_size = min(dataset_info['num_examples'], shuffle_buffer)
    per_worker_batch = batch_size * num_local_devices  # uses per-worker batch size

    ds = (
        tf.data.TFRecordDataset(filenames=record_path)
        .shard(jax.process_count(), jax.process_index())
        .shuffle(buffer_size)
        .map(_decode_example, tf.data.AUTOTUNE)
        .batch(per_worker_batch, drop_remainder=True)
        .map(_split_across_devices, tf.data.AUTOTUNE)
        .prefetch(1)  # prefetches the next batch
    )
    return ds, dataset_info