add files
- checkpoint.py +96 -0
- data_pipeline.py +85 -0
- dataset_utils/crop_image_borders.py +57 -0
- dataset_utils/images_to_tfrecords.py +145 -0
- fid/__init__.py +1 -0
- fid/core.py +150 -0
- fid/inception.py +655 -0
- fid/utils.py +59 -0
- generate_images.py +61 -0
- main.py +102 -0
- requirements.txt +14 -0
- stylegan2/__init__.py +5 -0
- stylegan2/discriminator.py +451 -0
- stylegan2/generator.py +713 -0
- stylegan2/ops.py +674 -0
- stylegan2/utils.py +37 -0
- training.py +382 -0
- training_steps.py +219 -0
- training_utils.py +174 -0
checkpoint.py
ADDED
@@ -0,0 +1,96 @@
import flax
import dill as pickle
import os
import builtins
from jax._src.lib import xla_client
import tensorflow as tf


# Hack: builtins is the module the bfloat16 type reports as its home, so
# register it there to make bfloat16 arrays picklable.
# https://github.com/google/jax/issues/8505
builtins.bfloat16 = xla_client.bfloat16


def pickle_dump(obj, filename):
    """ Wrapper to dump an object to a file."""
    with tf.io.gfile.GFile(filename, 'wb') as f:
        f.write(pickle.dumps(obj))


def pickle_load(filename):
    """ Wrapper to load an object from a file."""
    with tf.io.gfile.GFile(filename, 'rb') as f:
        pickled = pickle.loads(f.read())
    return pickled


def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
    """
    Saves checkpoint.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.
        state_G (train_state.TrainState): Generator state.
        state_D (train_state.TrainState): Discriminator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
        pl_mean (array): Moving average of the path length (generator regularization).
        config (argparse.Namespace): Configuration.
        step (int): Current step.
        epoch (int): Current epoch.
        fid_score (float): FID score corresponding to the checkpoint.
        keep (int): Number of checkpoints to keep.
    """
    state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
                  'state_D': flax.jax_utils.unreplicate(state_D),
                  'params_ema_G': params_ema_G,
                  'pl_mean': flax.jax_utils.unreplicate(pl_mean),
                  'config': config,
                  'fid_score': fid_score,
                  'step': step,
                  'epoch': epoch}

    pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if len(ckpts) > keep:
        modified_times = {}
        for ckpt in ckpts:
            stats = tf.io.gfile.stat(ckpt)
            modified_times[ckpt] = stats.mtime_nsec
        oldest_ckpt = sorted(modified_times, key=modified_times.get)[0]
        tf.io.gfile.remove(oldest_ckpt)


def load_checkpoint(filename):
    """
    Loads checkpoints.

    Args:
        filename (str): Path to the checkpoint file.

    Returns:
        (dict): Checkpoint.
    """
    state_dict = pickle_load(filename)
    return state_dict


def get_latest_checkpoint(ckpt_dir):
    """
    Returns the path of the latest checkpoint.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.

    Returns:
        (str): Path to latest checkpoint (if it exists).
    """
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if len(ckpts) == 0:
        return None

    modified_times = {}
    for ckpt in ckpts:
        stats = tf.io.gfile.stat(ckpt)
        modified_times[ckpt] = stats.mtime_nsec
    latest_ckpt = sorted(modified_times, key=modified_times.get)[-1]
    return latest_ckpt
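These helpers compose into a simple resume-from-latest pattern. A minimal sketch, assuming a hypothetical checkpoint directory (tf.io.gfile accepts both local and GCS paths):

import checkpoint

latest = checkpoint.get_latest_checkpoint('ckpts')  # 'ckpts' is a hypothetical directory
if latest is not None:
    ckpt = checkpoint.load_checkpoint(latest)
    # The keys mirror the state_dict written by save_checkpoint.
    config, step, epoch = ckpt['config'], ckpt['step'], ckpt['epoch']
    params_ema_G = ckpt['params_ema_G']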
data_pipeline.py
ADDED
@@ -0,0 +1,85 @@
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax
import numpy as np
from PIL import Image
import os
from typing import Sequence
from tqdm import tqdm
import json
import logging

logger = logging.getLogger(__name__)


def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    ds_iter = iter(dataset)
    ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter


def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
    """
    Loads a TFRecord dataset and prepares device-sharded batches for training.

    Args:
        data_dir (str): Root directory of the dataset.
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_local_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.

    Returns:
        (tf.data.Dataset): Dataset.
        (dict): Dataset info.
    """

    def pre_process(serialized_example):
        feature = {'height': tf.io.FixedLenFeature([], tf.int64),
                   'width': tf.io.FixedLenFeature([], tf.int64),
                   'channels': tf.io.FixedLenFeature([], tf.int64),
                   'image': tf.io.FixedLenFeature([], tf.string),
                   'label': tf.io.FixedLenFeature([], tf.int64)}
        example = tf.io.parse_single_example(serialized_example, feature)

        height = tf.cast(example['height'], dtype=tf.int64)
        width = tf.cast(example['width'], dtype=tf.int64)
        channels = tf.cast(example['channels'], dtype=tf.int64)

        image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        image = tf.reshape(image, shape=[height, width, channels])

        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)

        image = (image - 127.5) / 127.5

        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
        # because the first dimension will be mapped across devices using jax.pmap
        data['image'] = tf.reshape(data['image'], [num_local_devices, -1, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_local_devices, -1, num_classes])
        return data

    logger.info('Loading TFRecord...')
    with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
        dataset_info = json.load(fin)

    ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    ds = ds.shard(jax.process_count(), jax.process_index())
    ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    ds = ds.batch(batch_size * num_local_devices, drop_remainder=True)  # uses per-worker batch size
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)  # prefetches the next batch
    return ds, dataset_info
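A minimal sketch of how this pipeline is consumed on the training side; the data directory and batch size are hypothetical, and prefetch is the helper defined above:

import jax
from data_pipeline import get_data, prefetch

ds, dataset_info = get_data(data_dir='datasets/my_data', img_size=256, img_channels=3,
                            num_classes=0, num_local_devices=jax.local_device_count(),
                            batch_size=8)
for batch in prefetch(ds, n_prefetch=2):
    # batch['image'] has shape [num_local_devices, 8, 256, 256, 3], ready for jax.pmap
    break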
dataset_utils/crop_image_borders.py
ADDED
@@ -0,0 +1,57 @@
"""
Crops the black borders around images.
"""

import numpy as np
from PIL import Image
import os
from tqdm import tqdm
import argparse
import logging

logger = logging.getLogger(__name__)


def crop_border(x, constant=0.0):
    top = 0
    while True:
        if np.sum(x[top] != constant) != 0.0:
            break
        top += 1
    bottom = x.shape[0] - 1
    while True:
        if np.sum(x[bottom] != constant) != 0.0:
            bottom += 1
            break
        bottom -= 1
    left = 0
    while True:
        if np.sum(x[:, left] != constant) != 0.0:
            break
        left += 1
    right = x.shape[1] - 1
    while True:
        if np.sum(x[:, right] != constant) != 0.0:
            right += 1
            break
        right -= 1
    return x[top:bottom, left:right]


def crop_images(path, constant_value):
    logger.info('Crop image borders...')
    for f in tqdm(os.listdir(path)):
        img = Image.open(os.path.join(path, f))
        img = crop_border(np.array(img), constant=constant_value)
        img = Image.fromarray(img)
        img.save(os.path.join(path, f))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
    parser.add_argument('--constant_value', type=float, default=0.0, help='Value of the border that should be cropped.')

    args = parser.parse_args()

    crop_images(args.image_dir, args.constant_value)
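To make the cropping rule concrete, a toy example (values invented for illustration):

import numpy as np

x = np.zeros((6, 6), dtype=np.uint8)
x[2:4, 2:5] = 255  # content surrounded by a zero border
cropped = crop_border(x, constant=0.0)
assert cropped.shape == (2, 3)  # rows 2-3 and columns 2-4 survive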
dataset_utils/images_to_tfrecords.py
ADDED
@@ -0,0 +1,145 @@
import tensorflow as tf
import numpy as np
from PIL import Image
from typing import Sequence
from tqdm import tqdm
import argparse
import json
import os
import logging

logger = logging.getLogger(__name__)


def images_to_tfrecords(image_dir, data_dir, has_labels):
    """
    Converts a folder of images to a TFRecord file.

    The image directory should have one of the following structures:

    If has_labels = False, image_dir should look like this:

        path/to/image_dir/
            0.jpg
            1.jpg
            2.jpg
            4.jpg
            ...

    If has_labels = True, image_dir should look like this:

        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                c.jpg
                ...
            ...

    The labels will be label0 -> 0, label1 -> 1.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFRecords dataset is stored.
        has_labels (bool): If True, 'image_dir' contains label directories.

    Returns:
        (dict): Dataset info, which is also written to 'dataset_info.json' in data_dir.
    """

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    os.makedirs(data_dir, exist_ok=True)
    writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords'))

    num_examples = 0
    num_classes = 0

    if has_labels:
        for label_dir in os.listdir(image_dir):
            if not os.path.isdir(os.path.join(image_dir, label_dir)):
                logger.warning('The image directory should contain one directory for each label.')
                logger.warning('These label directories should contain the image files.')
                if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')):
                    os.remove(os.path.join(data_dir, 'dataset.tfrecords'))
                return

            for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))):
                file_format = img_file[img_file.rfind('.') + 1:]
                if file_format not in ['png', 'jpg', 'jpeg']:
                    continue

                # img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
                img = Image.open(os.path.join(image_dir, label_dir, img_file))
                img = np.array(img, dtype=np.uint8)

                height = img.shape[0]
                width = img.shape[1]
                channels = img.shape[2]

                img_encoded = img.tobytes()

                example = tf.train.Example(features=tf.train.Features(feature={
                    'height': _int64_feature(height),
                    'width': _int64_feature(width),
                    'channels': _int64_feature(channels),
                    'image': _bytes_feature(img_encoded),
                    'label': _int64_feature(num_classes)}))

                writer.write(example.SerializeToString())
                num_examples += 1

            num_classes += 1
    else:
        for img_file in tqdm(os.listdir(os.path.join(image_dir))):
            file_format = img_file[img_file.rfind('.') + 1:]
            if file_format not in ['png', 'jpg', 'jpeg']:
                continue

            # img = Image.open(os.path.join(image_dir, img_file)).resize(img_size)
            img = Image.open(os.path.join(image_dir, img_file))
            img = np.array(img, dtype=np.uint8)

            height = img.shape[0]
            width = img.shape[1]
            channels = img.shape[2]

            img_encoded = img.tobytes()

            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                'image': _bytes_feature(img_encoded),
                'label': _int64_feature(num_classes)}))  # dummy label

            writer.write(example.SerializeToString())
            num_examples += 1

    writer.close()

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)
    return dataset_info


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.')

    args = parser.parse_args()

    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)
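A sketch of programmatic use instead of the CLI; the paths are hypothetical:

images_to_tfrecords(image_dir='datasets/raw_images', data_dir='datasets/my_data', has_labels=False)
# 'datasets/my_data' now holds dataset.tfrecords and dataset_info.json,
# the layout that data_pipeline.get_data expects.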
fid/__init__.py
ADDED
@@ -0,0 +1 @@
from .core import FID
fid/core.py
ADDED
@@ -0,0 +1,150 @@
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import numpy as np
import os
import functools
import argparse
import scipy.linalg
from tqdm import tqdm
import logging

from . import inception
from . import utils

logger = logging.getLogger(__name__)


class FID:

    def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
        """
        Evaluates the FID score for a given generator and a given dataset.
        Implementation mostly taken from https://github.com/matthias-wright/jax-fid

        Reference: https://arxiv.org/abs/1706.08500

        Args:
            generator (nn.Module): Generator network.
            dataset (tf.data.Dataset): Dataset containing the real images.
            config (argparse.Namespace): Configuration.
            use_cache (bool): If True, only compute the activation stats once for the real images and store them.
            truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
        """
        self.num_images = config.num_fid_images
        self.batch_size = config.batch_size
        self.c_dim = config.c_dim
        self.z_dim = config.z_dim
        self.dataset = dataset
        self.num_devices = jax.device_count()
        self.num_local_devices = jax.local_device_count()
        self.use_cache = use_cache

        if self.use_cache:
            self.cache = {}

        rng = jax.random.PRNGKey(0)
        inception_net = inception.InceptionV3(pretrained=True)
        self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
        self.inception_params = flax.jax_utils.replicate(self.inception_params)
        # self.inception = jax.jit(functools.partial(model.apply, train=False))
        self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')

        self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch')

    def compute_fid(self, generator_params, seed_offset=0):
        generator_params = flax.jax_utils.replicate(generator_params)
        mu_real, sigma_real = self.compute_stats_for_dataset()
        mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
        fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
        return fid_score

    def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
        # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_1d(sigma1)
        sigma2 = np.atleast_1d(sigma2)

        assert mu1.shape == mu2.shape
        assert sigma1.shape == sigma2.shape

        diff = mu1 - mu2

        covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            logger.info(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

    def compute_stats_for_dataset(self):
        if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
            logger.info('Use cached statistics for dataset...')
            return self.cache['mu'], self.cache['sigma']

        print()
        logger.info('Compute statistics for dataset...')
        image_count = 0

        activations = []
        for batch in utils.prefetch(self.dataset, n_prefetch=2):
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
            act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
            activations.append(act)

            image_count += self.num_local_devices * self.batch_size
            if image_count >= self.num_images:
                break

        activations = jnp.concatenate(activations, axis=0)
        activations = activations[:self.num_images]
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        if self.use_cache:
            self.cache['mu'] = mu
            self.cache['sigma'] = sigma
        return mu, sigma

    def compute_stats_for_generator(self, generator_params, seed_offset):
        print()
        logger.info('Compute statistics for generator...')
        num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_local_devices)))

        activations = []

        for i in range(num_batches):
            rng = jax.random.PRNGKey(seed_offset + i)
            z_latent = jax.random.normal(rng, shape=(self.num_local_devices, self.batch_size, self.z_dim))

            labels = None
            if self.c_dim > 0:
                labels = jax.random.randint(rng, shape=(self.num_local_devices * self.batch_size,), minval=0, maxval=self.c_dim)
                labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
                labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))

            image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
            image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))

            image = 2 * image - 1
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
            act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
            activations.append(act)

        activations = jnp.concatenate(activations, axis=0)
        activations = activations[:self.num_images]
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return mu, sigma
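A minimal sketch of evaluating FID, assuming generator_ema, dataset, config, and the unreplicated EMA parameters params_ema_G are set up as in the rest of this repo:

fid_model = FID(generator_ema, dataset, config, use_cache=True)
fid_score = fid_model.compute_fid(params_ema_G)  # replicates the params internally
print(f'FID: {fid_score:.2f}')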
fid/inception.py
ADDED
@@ -0,0 +1,655 @@
import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp
import flax
from flax.linen.module import merge_param
import flax.linen as nn
from typing import Callable, Iterable, Optional, Tuple, Union, Any
import functools
import pickle
from . import utils

PRNGKey = Any
Array = Any
Shape = Tuple[int]
Dtype = Any


class InceptionV3(nn.Module):
    """
    InceptionV3 network.
    Reference: https://arxiv.org/abs/1512.00567
    Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py

    Attributes:
        include_head (bool): If True, include classifier head.
        num_classes (int): Number of classes.
        pretrained (bool): If True, use pretrained weights.
        transform_input (bool): If True, preprocesses the input according to the method with which it
                                was trained on ImageNet.
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
        dtype (str): Data type.
    """
    include_head: bool=False
    num_classes: int=1000
    pretrained: bool=False
    transform_input: bool=False
    aux_logits: bool=False
    ckpt_path: str='https://www.dropbox.com/s/0zo4pd6cfwgzem7/inception_v3_weights_fid.pickle?dl=1'
    dtype: str='float32'

    def setup(self):
        if self.pretrained:
            ckpt_file = utils.download(self.ckpt_path)
            self.params_dict = pickle.load(open(ckpt_file, 'rb'))
            self.num_classes_ = 1000
        else:
            self.params_dict = None
            self.num_classes_ = self.num_classes

    @nn.compact
    def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
        """
        Args:
            x (tensor): Input image, shape [B, H, W, C].
            train (bool): If True, training mode.
            rng (jax.random.PRNGKey): Random seed.
        """
        x = self._transform_input(x)
        x = BasicConv2d(out_channels=32,
                        kernel_size=(3, 3),
                        strides=(2, 2),
                        params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=32,
                        kernel_size=(3, 3),
                        params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=64,
                        kernel_size=(3, 3),
                        padding=((1, 1), (1, 1)),
                        params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
                        dtype=self.dtype)(x, train)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = BasicConv2d(out_channels=80,
                        kernel_size=(1, 1),
                        params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=192,
                        kernel_size=(3, 3),
                        params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
                        dtype=self.dtype)(x, train)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = InceptionA(pool_features=32,
                       params_dict=utils.get(self.params_dict, 'Mixed_5b'),
                       dtype=self.dtype)(x, train)
        x = InceptionA(pool_features=64,
                       params_dict=utils.get(self.params_dict, 'Mixed_5c'),
                       dtype=self.dtype)(x, train)
        x = InceptionA(pool_features=64,
                       params_dict=utils.get(self.params_dict, 'Mixed_5d'),
                       dtype=self.dtype)(x, train)
        x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=128,
                       params_dict=utils.get(self.params_dict, 'Mixed_6b'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=160,
                       params_dict=utils.get(self.params_dict, 'Mixed_6c'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=160,
                       params_dict=utils.get(self.params_dict, 'Mixed_6d'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=192,
                       params_dict=utils.get(self.params_dict, 'Mixed_6e'),
                       dtype=self.dtype)(x, train)
        aux = None
        if self.aux_logits and train:
            aux = InceptionAux(num_classes=self.num_classes_,
                               params_dict=utils.get(self.params_dict, 'AuxLogits'),
                               dtype=self.dtype)(x, train)
        x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
                       dtype=self.dtype)(x, train)
        x = InceptionE(avg_pool, params_dict=utils.get(self.params_dict, 'Mixed_7b'),
                       dtype=self.dtype)(x, train)
        # Following the implementation by @mseitzer, we use max pooling instead
        # of average pooling here.
        # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
        x = InceptionE(nn.max_pool, params_dict=utils.get(self.params_dict, 'Mixed_7c'),
                       dtype=self.dtype)(x, train)
        x = jnp.mean(x, axis=(1, 2), keepdims=True)
        if not self.include_head:
            return x
        x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
        x = jnp.reshape(x, newshape=(x.shape[0], -1))
        x = Dense(features=self.num_classes_,
                  params_dict=utils.get(self.params_dict, 'fc'),
                  dtype=self.dtype)(x)
        if self.aux_logits:
            return x, aux
        return x

    def _transform_input(self, x):
        if self.transform_input:
            x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
        return x


class Dense(nn.Module):
    features: int
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features,
                     kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['kernel']),
                     bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['bias']))(x)
        return x


class BasicConv2d(nn.Module):
    out_channels: int
    kernel_size: Union[int, Iterable[int]]=(3, 3)
    strides: Optional[Iterable[int]]=(1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]]='valid'
    use_bias: bool=False
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        x = nn.Conv(features=self.out_channels,
                    kernel_size=self.kernel_size,
                    strides=self.strides,
                    padding=self.padding,
                    use_bias=self.use_bias,
                    kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['kernel']),
                    bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['bias']),
                    dtype=self.dtype)(x)
        if self.params_dict is None:
            x = BatchNorm(epsilon=0.001,
                          momentum=0.1,
                          use_running_average=not train,
                          dtype=self.dtype)(x)
        else:
            x = BatchNorm(epsilon=0.001,
                          momentum=0.1,
                          bias_init=lambda *_ : jnp.array(self.params_dict['bn']['bias']),
                          scale_init=lambda *_ : jnp.array(self.params_dict['bn']['scale']),
                          mean_init=lambda *_ : jnp.array(self.params_dict['bn']['mean']),
                          var_init=lambda *_ : jnp.array(self.params_dict['bn']['var']),
                          use_running_average=not train,
                          dtype=self.dtype)(x)
        x = jax.nn.relu(x)
        return x


class InceptionA(nn.Module):
    pool_features: int
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=64,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
        branch5x5 = BasicConv2d(out_channels=48,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch5x5_1'),
                                dtype=self.dtype)(x, train)
        branch5x5 = BasicConv2d(out_channels=64,
                                kernel_size=(5, 5),
                                padding=((2, 2), (2, 2)),
                                params_dict=utils.get(self.params_dict, 'branch5x5_2'),
                                dtype=self.dtype)(branch5x5, train)

        branch3x3dbl = BasicConv2d(out_channels=64,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
                                   dtype=self.dtype)(branch3x3dbl, train)

        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=self.pool_features,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)

        output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionB(nn.Module):
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch3x3 = BasicConv2d(out_channels=384,
                                kernel_size=(3, 3),
                                strides=(2, 2),
                                params_dict=utils.get(self.params_dict, 'branch3x3'),
                                dtype=self.dtype)(x, train)

        branch3x3dbl = BasicConv2d(out_channels=64,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   strides=(2, 2),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
                                   dtype=self.dtype)(branch3x3dbl, train)

        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionC(nn.Module):
    channels_7x7: int
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=192,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)

        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch7x7_1'),
                                dtype=self.dtype)(x, train)
        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
                                kernel_size=(1, 7),
                                padding=((0, 0), (3, 3)),
                                params_dict=utils.get(self.params_dict, 'branch7x7_2'),
                                dtype=self.dtype)(branch7x7, train)
        branch7x7 = BasicConv2d(out_channels=192,
                                kernel_size=(7, 1),
                                padding=((3, 3), (0, 0)),
                                params_dict=utils.get(self.params_dict, 'branch7x7_3'),
                                dtype=self.dtype)(branch7x7, train)

        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(7, 1),
                                   padding=((3, 3), (0, 0)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 7),
                                   padding=((0, 0), (3, 3)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(7, 1),
                                   padding=((3, 3), (0, 0)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 7),
                                   padding=((0, 0), (3, 3)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
                                   dtype=self.dtype)(branch7x7dbl, train)

        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)

        output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
        return output


class InceptionD(nn.Module):
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch3x3 = BasicConv2d(out_channels=192,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
                                dtype=self.dtype)(x, train)
        branch3x3 = BasicConv2d(out_channels=320,
                                kernel_size=(3, 3),
                                strides=(2, 2),
                                params_dict=utils.get(self.params_dict, 'branch3x3_2'),
                                dtype=self.dtype)(branch3x3, train)

        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
                                  dtype=self.dtype)(x, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 7),
                                  padding=((0, 0), (3, 3)),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(7, 1),
                                  padding=((3, 3), (0, 0)),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(3, 3),
                                  strides=(2, 2),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
                                  dtype=self.dtype)(branch7x7x3, train)

        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
        return output


class InceptionE(nn.Module):
    pooling: Callable
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=320,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)

        branch3x3 = BasicConv2d(out_channels=384,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
                                dtype=self.dtype)(x, train)
        branch3x3_a = BasicConv2d(out_channels=384,
                                  kernel_size=(1, 3),
                                  padding=((0, 0), (1, 1)),
                                  params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
                                  dtype=self.dtype)(branch3x3, train)
        branch3x3_b = BasicConv2d(out_channels=384,
                                  kernel_size=(3, 1),
                                  padding=((1, 1), (0, 0)),
                                  params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
                                  dtype=self.dtype)(branch3x3, train)
        branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)

        branch3x3dbl = BasicConv2d(out_channels=448,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=384,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl_a = BasicConv2d(out_channels=384,
                                     kernel_size=(1, 3),
                                     padding=((0, 0), (1, 1)),
                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
                                     dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl_b = BasicConv2d(out_channels=384,
                                     kernel_size=(3, 1),
                                     padding=((1, 1), (0, 0)),
                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
                                     dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)

        branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)

        output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionAux(nn.Module):
    num_classes: int
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
        x = BasicConv2d(out_channels=128,
                        kernel_size=(1, 1),
                        params_dict=utils.get(self.params_dict, 'conv0'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=768,
                        kernel_size=(5, 5),
                        params_dict=utils.get(self.params_dict, 'conv1'),
                        dtype=self.dtype)(x, train)
        x = jnp.mean(x, axis=(1, 2))
        x = jnp.reshape(x, newshape=(x.shape[0], -1))
        x = Dense(features=self.num_classes,
                  params_dict=utils.get(self.params_dict, 'fc'),
                  dtype=self.dtype)(x)
        return x


def _absolute_dims(rank, dims):
    return tuple([rank + dim if dim < 0 else dim for dim in dims])


class BatchNorm(nn.Module):
    """BatchNorm Module.
    Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
    Attributes:
        use_running_average: if True, the statistics stored in batch_stats
            will be used instead of computing the batch statistics on the input.
        axis: the feature or non-batch axis of the input.
        momentum: decay rate for the exponential moving average of the batch statistics.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        use_bias: if True, bias (beta) is added.
        use_scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
        axis_name: the axis name used to combine batch statistics from multiple
            devices. See `jax.pmap` for a description of axis names (default: None).
        axis_index_groups: groups of axis indices within that named axis
            representing subsets of devices to reduce over (default: None). For
            example, `[[0, 1], [2, 3]]` would independently batch-normalize over
            the examples on the first two and last two devices. See `jax.lax.psum`
            for more details.
    """
    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 1e-5
    dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
    mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
    var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
    axis_name: Optional[str] = None
    axis_index_groups: Any = None

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

        NOTE:
        During initialization (when parameters are mutable) the running average
        of the batch statistics will not be updated. Therefore, the inputs
        fed during initialization don't need to match that of the actual input
        distribution and the reduction axis (set with `axis_name`) does not have
        to exist.
        Args:
            x: the input to be normalized.
            use_running_average: if true, the statistics stored in batch_stats
                will be used instead of computing the batch statistics on the input.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        use_running_average = merge_param(
            'use_running_average', self.use_running_average, use_running_average)
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # see NOTE above on initialization behavior
        initializing = self.is_mutable_collection('params')

        ra_mean = self.variable('batch_stats', 'mean',
                                self.mean_init,
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               self.var_init,
                               reduced_feature_shape)

        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(
                        concatenated_mean,
                        axis_name=self.axis_name,
                        axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale',
                               self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias',
                              self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """
    Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py

    Helper function to define pooling functions.
    Pooling functions are implemented using the ReduceWindow XLA op.
    NOTE: Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply
    that pool is differentiable.
    Args:
        inputs: input data with dimensions (batch, window dims..., features).
        init: the initial value for the reduction
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        window_shape: a shape tuple defining the window to reduce over.
        strides: a sequence of `n` integers, representing the inter-window
            strides.
        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply before
            and after each spatial dimension.
    Returns:
        The output of the reduction for each window slice.
    """
    strides = strides or (1,) * len(window_shape)
    assert len(window_shape) == len(strides), (
        f"len({window_shape}) == len({strides})")
    strides = (1,) + strides + (1,)
    dims = (1,) + window_shape + (1,)

    is_single_input = False
    if inputs.ndim == len(dims) - 1:
        # add singleton batch dimension because lax.reduce_window always
        # needs a batch dimension.
        inputs = inputs[None]
        is_single_input = True

    assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
    if not isinstance(padding, str):
        padding = tuple(map(tuple, padding))
        assert(len(padding) == len(window_shape)), (
            f"padding {padding} must specify pads for same number of dims as "
            f"window_shape {window_shape}")
        assert(all([len(x) == 2 for x in padding])), (
            f"each entry in padding {padding} must be length 2")
        padding = ((0,0),) + padding + ((0,0),)
    y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    if is_single_input:
        y = jnp.squeeze(y, axis=0)
    return y


def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.

    In comparison to flax.linen.avg_pool, this pooling operation does not
    consider the padded zero's for the average computation.

    Args:
        inputs: input data with dimensions (batch, window dims..., features).
        window_shape: a shape tuple defining the window to reduce over.
        strides: a sequence of `n` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
            of `n` `(low, high)` integer pairs that give the padding to apply before
            and after each spatial dimension (default: `'VALID'`).
    Returns:
        The average for each window slice.
    """
    assert inputs.ndim == 4
    assert len(window_shape) == 2

    y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
    counts = jax.lax.conv_general_dilated(ones,
                                          jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)),
                                          window_strides=(1, 1),
                                          padding=((1, 1), (1, 1)),
                                          dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
                                          feature_group_count=1)
    y = y / counts
    return y
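A minimal single-device sketch of using this network as the FID feature extractor (the FID class above does the same with jax.pmap); inputs are assumed to be scaled to [-1, 1]:

import jax
import jax.numpy as jnp
import functools

net = InceptionV3(pretrained=True)  # downloads and loads the pickled weights
params = net.init(jax.random.PRNGKey(0), jnp.ones((1, 256, 256, 3)))
apply_fn = jax.jit(functools.partial(net.apply, train=False))
feats = apply_fn(params, jnp.zeros((4, 256, 256, 3)))  # pool features, shape [4, 1, 1, 2048]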
fid/utils.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import jax
import flax
import numpy as np
from tqdm import tqdm
import requests
import os
import tempfile
import logging

logger = logging.getLogger(__name__)


def download(url, ckpt_dir=None):
    name = url[url.rfind('/') + 1 : url.rfind('?')]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        logger.info(f'Downloading: "{url[:url.rfind("?")]}" to {ckpt_file}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

        # First write to a temp file, in case the download fails.
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            logger.error('An error occurred while downloading, please try again.')
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
        else:
            # If the download was successful, rename the temp file.
            os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file


def get(dictionary, key):
    if dictionary is None or key not in dictionary:
        return None
    return dictionary[key]


def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    ds_iter = iter(dataset)
    ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter
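Usage sketch (not part of the commit; the URL is a placeholder): download() caches a file under <ckpt_dir>/flaxmodels, deriving the name from the URL segment between the last '/' and the '?'; prefetch() turns a tf.data iterator into numpy batches that flax.jax_utils.prefetch_to_device ships to the accelerator ahead of time, which requires a leading axis equal to the local device count.

# Sketch: cached download (placeholder URL) and device prefetching.
import jax
import tensorflow as tf
from fid import utils

path = utils.download('https://example.com/weights.h5?dl=1')  # placeholder URL

n_dev = jax.local_device_count()
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8 * n_dev, 64, 64, 3)))
ds = ds.batch(8, drop_remainder=True).batch(n_dev, drop_remainder=True)
it = utils.prefetch(ds, n_prefetch=2)
batch = next(it)  # shape [n_dev, 8, 64, 64, 3], already on device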
generate_images.py
ADDED
@@ -0,0 +1,61 @@
import argparse
import functools
import logging
import os

import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from tqdm import tqdm

import checkpoint
from stylegan2.generator import Generator

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s', force=True)
logger = logging.getLogger(__name__)


def generate_images(args):
    logger.info(f"Loading checkpoint '{args.checkpoint}'...")
    ckpt = checkpoint.load_checkpoint(args.checkpoint)
    config = ckpt['config']
    params_ema_G = ckpt['params_ema_G']

    generator_ema = Generator(
        resolution=config.resolution,
        num_channels=config.img_channels,
        z_dim=config.z_dim,
        c_dim=config.c_dim,
        w_dim=config.w_dim,
        num_ws=int(np.log2(config.resolution)) * 2 - 3,
        num_mapping_layers=8,
        fmap_base=config.fmap_base,
        dtype=jnp.float32
    )

    generator_apply = jax.jit(
        functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const')
    )

    logger.info(f"Generating {len(args.seeds)} images with truncation {args.truncation_psi}...")
    for seed in tqdm(args.seeds):
        rng = jax.random.PRNGKey(seed)
        z_latent = jax.random.normal(rng, shape=(1, config.z_dim))
        image = generator_apply(params_ema_G, jax.lax.stop_gradient(z_latent), None)
        # Normalize to [0, 1] before converting to uint8.
        image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))

        Image.fromarray(np.uint8(np.clip(image[0] * 255, 0, 255))).save(os.path.join(args.out_path, f'{seed}.png'))
    logger.info(f"Images saved in '{args.out_path}/'")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, help='Path to the checkpoint.', required=True)
    parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.')
    parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.')
    parser.add_argument('--seeds', type=int, nargs='*', default=[0], help='List of random seeds.')
    args = parser.parse_args()
    os.makedirs(args.out_path, exist_ok=True)

    generate_images(args)
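Programmatic usage sketch (not part of the commit; the checkpoint path is a placeholder). Each seed maps deterministically to one latent via jax.random.PRNGKey(seed), so re-running with the same seeds reproduces the same images.

# Sketch: drive generate_images() without the CLI.
import argparse
import os
import generate_images as gi

args = argparse.Namespace(checkpoint='ckpt_50000.pickle',  # placeholder path
                          out_path='generated_images',
                          truncation_psi=0.7,
                          seeds=[0, 1, 2])
os.makedirs(args.out_path, exist_ok=True)
gi.generate_images(args)  # writes 0.png, 1.png, 2.png into generated_images/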
main.py
ADDED
@@ -0,0 +1,102 @@
import argparse
import os
import jax
import wandb
import training
import logging
import json


logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s', force=True)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()
    # Paths
    parser.add_argument('--data_dir', type=str, required=True, help='Directory of the dataset.')
    parser.add_argument('--save_dir', type=str, default='gs://ig-standard-usc1/sg2-flax/checkpoints/', help='Directory where checkpoints will be written to. A subfolder with run_id will be created.')
    parser.add_argument('--load_from_pkl', type=str, help='If provided, start training from an existing checkpoint pickle file.')
    parser.add_argument('--resume_run_id', type=str, help='If provided, resume existing training run. If --wandb is enabled W&B will also resume.')
    parser.add_argument('--project', type=str, default='sg2-flax', help='Name of this project.')
    # Training
    parser.add_argument('--num_epochs', type=int, default=10000, help='Number of epochs.')
    parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size.')
    parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.')
    parser.add_argument('--resolution', type=int, default=128, help='Image resolution. Must be a power of 2.')
    parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
    parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
    parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--bf16', action='store_true', help='Use bf16 dtype (This is still WIP).')
    # Generator
    parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.')
    # Discriminator
    parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.')
    # Exponentially Moving Average of Generator Weights
    parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).')
    # Losses
    parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponential decay for the mean of the path length (path length regularization).')
    parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.')
    # Regularization
    parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.')
    parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.')
    parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.')
    parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.')
    # Model
    parser.add_argument('--z_dim', type=int, default=512, help='Input latent (Z) dimensionality.')
    parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.')
    parser.add_argument('--w_dim', type=int, default=512, help='Intermediate latent (W) dimensionality.')
    # Logging
    parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.')
    parser.add_argument('--save_every', type=int, default=2000, help='Save every save_every steps. Will be ignored if FID evaluation is enabled.')
    parser.add_argument('--generate_samples_every', type=int, default=10000, help='Generate samples every generate_samples_every steps.')
    parser.add_argument('--debug', action='store_true', help='Show debug log.')
    # FID
    parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.')
    parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.')
    parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.')
    # W&B
    parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
    parser.add_argument('--name', type=str, default=None, help='Name of this experiment in Weights&Biases.')
    parser.add_argument('--entity', type=str, default='nyxai', help='Entity for this experiment in Weights&Biases.')
    parser.add_argument('--group', type=str, default=None, help='Group name of this experiment for Weights&Biases.')

    args = parser.parse_args()

    # Debug mode.
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # Some validation.
    if args.resume_run_id is not None:
        assert args.load_from_pkl is None, 'When resuming a run one cannot also specify --load_from_pkl'

    # Set unique run ID.
    if args.resume_run_id:
        resume = 'must'  # throw error if the id to be resumed cannot be found
        args.run_id = args.resume_run_id
    else:
        resume = None  # default
        args.run_id = wandb.util.generate_id()
    args.ckpt_dir = os.path.join(args.save_dir, args.run_id)

    if jax.process_index() == 0:
        if not args.ckpt_dir.startswith('gs://') and not os.path.exists(args.ckpt_dir):
            os.makedirs(args.ckpt_dir)
        if args.wandb:
            wandb.init(id=args.run_id,
                       project=args.project,
                       group=args.group,
                       config=args,
                       name=args.name,
                       entity=args.entity,
                       resume=resume)
        logger.info('Starting new run with config:')
        print(json.dumps(vars(args), indent=4))

    training.train_and_evaluate(args)


if __name__ == '__main__':
    main()
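Sketch of how run ids and checkpoint directories relate (the id below is made up; wandb.util.generate_id() returns a random short string). Resuming with --resume_run_id reuses the same subfolder, and with --wandb the run is resumed with resume='must', so an unknown id raises an error.

# Sketch: run_id -> ckpt_dir mapping (example id is made up).
import os
import wandb

save_dir = 'gs://ig-standard-usc1/sg2-flax/checkpoints/'
run_id = wandb.util.generate_id()          # e.g. '3f2a1b0c' (made up)
ckpt_dir = os.path.join(save_dir, run_id)  # gs://.../checkpoints/3f2a1b0c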
requirements.txt
ADDED
@@ -0,0 +1,14 @@
flaxmodels==0.1.1
flax==0.4.1
jax==0.3.14
tensorflow==2.4.1
optax==0.0.9
numpy
tensorflow-datasets
argparse
wandb
tqdm
dill
h5py
dataclasses
stylegan2/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .generator import SynthesisNetwork
from .generator import MappingNetwork
from .generator import Generator
from .discriminator import Discriminator
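These re-exports allow the package-level sketch below. The init/apply signature is assumed from its use in generate_images.py and is not fully shown in this commit excerpt.

# Sketch: package-level API; signatures assumed, see generate_images.py.
import jax
from stylegan2 import Generator, Discriminator

rng = jax.random.PRNGKey(0)
G = Generator(resolution=128, num_channels=3, z_dim=512, c_dim=0, w_dim=512)
z = jax.random.normal(rng, (1, 512))
variables = G.init(rng, z, None, train=False)                       # assumed signature
img = G.apply(variables, z, None, train=False, noise_mode='const')  # NHWC image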
stylegan2/discriminator.py
ADDED
@@ -0,0 +1,451 @@
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
from typing import Any, Tuple, List, Callable
import h5py
from . import ops
from stylegan2 import utils


URLS = {'afhqcat': 'https://www.dropbox.com/s/qygbjkefyqyu9k9/stylegan2_discriminator_afhqcat.h5?dl=1',
        'afhqdog': 'https://www.dropbox.com/s/kmoxbp33qswz64p/stylegan2_discriminator_afhqdog.h5?dl=1',
        'afhqwild': 'https://www.dropbox.com/s/jz1hpsyt3isj6e7/stylegan2_discriminator_afhqwild.h5?dl=1',
        'brecahad': 'https://www.dropbox.com/s/h0cb89hruo6pmyj/stylegan2_discriminator_brecahad.h5?dl=1',
        'car': 'https://www.dropbox.com/s/2ghjrmxih7cic76/stylegan2_discriminator_car.h5?dl=1',
        'cat': 'https://www.dropbox.com/s/zfhjsvlsny5qixd/stylegan2_discriminator_cat.h5?dl=1',
        'church': 'https://www.dropbox.com/s/jlno7zeivkjtk8g/stylegan2_discriminator_church.h5?dl=1',
        'cifar10': 'https://www.dropbox.com/s/eldpubfkl4c6rur/stylegan2_discriminator_cifar10.h5?dl=1',
        'ffhq': 'https://www.dropbox.com/s/m42qy9951b7lq1s/stylegan2_discriminator_ffhq.h5?dl=1',
        'horse': 'https://www.dropbox.com/s/19f5pxrcdh2g8cw/stylegan2_discriminator_horse.h5?dl=1',
        'metfaces': 'https://www.dropbox.com/s/xnokaunql12glkd/stylegan2_discriminator_metfaces.h5?dl=1'}

RESOLUTION = {'metfaces': 1024,
              'ffhq': 1024,
              'church': 256,
              'cat': 256,
              'horse': 256,
              'car': 512,
              'brecahad': 512,
              'afhqwild': 512,
              'afhqdog': 512,
              'afhqcat': 512,
              'cifar10': 32}

C_DIM = {'metfaces': 0,
         'ffhq': 0,
         'church': 0,
         'cat': 0,
         'horse': 0,
         'car': 0,
         'brecahad': 0,
         'afhqwild': 0,
         'afhqdog': 0,
         'afhqcat': 0,
         'cifar10': 10}

ARCHITECTURE = {'metfaces': 'resnet',
                'ffhq': 'resnet',
                'church': 'resnet',
                'cat': 'resnet',
                'horse': 'resnet',
                'car': 'resnet',
                'brecahad': 'resnet',
                'afhqwild': 'resnet',
                'afhqdog': 'resnet',
                'afhqcat': 'resnet',
                'cifar10': 'orig'}

MBSTD_GROUP_SIZE = {'metfaces': None,
                    'ffhq': None,
                    'church': None,
                    'cat': None,
                    'horse': None,
                    'car': None,
                    'brecahad': None,
                    'afhqwild': None,
                    'afhqdog': None,
                    'afhqcat': None,
                    'cifar10': 32}


class FromRGBLayer(nn.Module):
    """
    From RGB Layer.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=1
    lr_multiplier: float=1
    activation: str='leaky_relu'
    param_dict: h5py.Group=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x, y):
        """
        Run From RGB Layer.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            y (tensor): Input tensor of shape [N, H, W, out_channels].

        Returns:
            (tensor): Output tensor of shape [N, H, W, out_channels].
        """
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'fromrgb', self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        b = self.param(name='bias', init_fn=lambda *_ : b)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        x = ops.conv2d(x, w.astype(x.dtype))
        x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        if y is not None:
            x += y
        return x


class DiscriminatorLayer(nn.Module):
    """
    Discriminator Layer.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        use_bias (bool): If True, use bias.
        down (bool): If True, downsample the spatial resolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        layer_name (str): Layer name.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        lr_multiplier (float): Learning rate multiplier.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=3
    use_bias: bool=True
    down: bool=False
    resample_kernel: Tuple=None
    activation: str='leaky_relu'
    layer_name: str=None
    param_dict: h5py.Group=None
    lr_multiplier: float=1
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Discriminator Layer.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        if self.use_bias:
            w, b = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
        else:
            w = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_ : b)
            b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        x = ops.conv2d(x, w, down=self.down, resample_kernel=self.resample_kernel)
        if self.use_bias: x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        return x


class DiscriminatorBlock(nn.Module):
    """
    Discriminator Block.

    Attributes:
        res (int): Resolution (log2) of the current block.
        kernel (int): Kernel size of the convolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        lr_multiplier (float): Learning rate multiplier.
        architecture (str): Architecture: 'orig', 'resnet'.
        nf (Callable): Callable that returns the number of feature maps for a given layer.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    res: int
    kernel: int=3
    resample_kernel: Tuple=(1, 3, 3, 1)
    activation: str='leaky_relu'
    param_dict: Any=None
    lr_multiplier: float=1
    architecture: str='resnet'
    nf: Callable=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Discriminator Block.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        init_rng = self.rng
        x = x.astype(self.dtype)
        residual = x
        for i in range(2):
            init_rng, init_key = random.split(init_rng)
            x = DiscriminatorLayer(fmaps=self.nf(self.res - (i + 1)),
                                   kernel=self.kernel,
                                   down=i == 1,
                                   resample_kernel=self.resample_kernel if i == 1 else None,
                                   activation=self.activation,
                                   layer_name=f'conv{i}',
                                   param_dict=self.param_dict,
                                   lr_multiplier=self.lr_multiplier,
                                   clip_conv=self.clip_conv,
                                   dtype=self.dtype,
                                   rng=init_key)(x)

        if self.architecture == 'resnet':
            init_rng, init_key = random.split(init_rng)
            residual = DiscriminatorLayer(fmaps=self.nf(self.res - 2),
                                          kernel=1,
                                          use_bias=False,
                                          down=True,
                                          resample_kernel=self.resample_kernel,
                                          activation='linear',
                                          layer_name='skip',
                                          param_dict=self.param_dict,
                                          lr_multiplier=self.lr_multiplier,
                                          dtype=self.dtype,
                                          rng=init_key)(residual)

            x = (x + residual) * np.sqrt(0.5, dtype=x.dtype)
        return x


class Discriminator(nn.Module):
    """
    Discriminator.

    Attributes:
        resolution (int): Input resolution. Overridden based on dataset.
        num_channels (int): Number of input color channels. Overridden based on dataset.
        c_dim (int): Dimensionality of the labels (c), 0 if no labels. Overridden based on dataset.
        fmap_base (int): Overall multiplier for the number of feature maps.
        fmap_decay (int): Log2 feature map reduction when doubling the resolution.
        fmap_min (int): Minimum number of feature maps in any layer.
        fmap_max (int): Maximum number of feature maps in any layer.
        mapping_layers (int): Number of additional mapping layers for the conditioning labels.
        mapping_fmaps (int): Number of activations in the mapping layers, None = default.
        mapping_lr_multiplier (float): Learning rate multiplier for the mapping layers.
        architecture (str): Architecture: 'orig', 'resnet'.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        mbstd_group_size (int): Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_features (int): Number of features for the minibatch standard deviation layer, 0 = disable.
        resample_kernel (Tuple): Low-pass filter to apply when resampling activations, None = box filter.
        num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        pretrained (str): Use pretrained model, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Input dimensions.
    resolution: int=1024
    num_channels: int=3
    c_dim: int=0

    # Capacity.
    fmap_base: int=16384
    fmap_decay: int=1
    fmap_min: int=1
    fmap_max: int=512

    # Internal details.
    mapping_layers: int=0
    mapping_fmaps: int=None
    mapping_lr_multiplier: float=0.1
    architecture: str='resnet'
    activation: str='leaky_relu'
    mbstd_group_size: int=None
    mbstd_num_features: int=1
    resample_kernel: Tuple=(1, 3, 3, 1)
    num_fp16_res: int=0
    clip_conv: float=None

    # Pretraining
    pretrained: str=None
    ckpt_dir: str=None

    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        self.resolution_ = self.resolution
        self.c_dim_ = self.c_dim
        self.architecture_ = self.architecture
        self.mbstd_group_size_ = self.mbstd_group_size
        self.param_dict = None
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict = h5py.File(ckpt_file, 'r')['discriminator']
            self.resolution_ = RESOLUTION[self.pretrained]
            self.architecture_ = ARCHITECTURE[self.pretrained]
            self.mbstd_group_size_ = MBSTD_GROUP_SIZE[self.pretrained]
            self.c_dim_ = C_DIM[self.pretrained]

        assert self.architecture in ['orig', 'resnet']

    @nn.compact
    def __call__(self, x, c=None):
        """
        Run Discriminator.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            c (tensor): Input labels, shape [N, c_dim].

        Returns:
            (tensor): Output tensor of shape [N, 1].
        """
        resolution_log2 = int(np.log2(self.resolution_))
        assert self.resolution_ == 2**resolution_log2 and self.resolution_ >= 4
        def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
        if self.mapping_fmaps is None:
            mapping_fmaps = nf(0)
        else:
            mapping_fmaps = self.mapping_fmaps

        init_rng = self.rng
        # Label embedding and mapping.
        if self.c_dim_ > 0:
            c = ops.LinearLayer(in_features=self.c_dim_,
                                out_features=mapping_fmaps,
                                lr_multiplier=self.mapping_lr_multiplier,
                                param_dict=self.param_dict,
                                layer_name='label_embedding',
                                dtype=self.dtype,
                                rng=init_rng)(c)

            c = ops.normalize_2nd_moment(c)
            for i in range(self.mapping_layers):
                init_rng, init_key = random.split(init_rng)
                c = ops.LinearLayer(in_features=c.shape[1],
                                    out_features=mapping_fmaps,
                                    lr_multiplier=self.mapping_lr_multiplier,
                                    param_dict=self.param_dict,
                                    layer_name=f'fc{i}',
                                    dtype=self.dtype,
                                    rng=init_key)(c)

        # Layers for >=8x8 resolutions.
        y = None
        for res in range(resolution_log2, 2, -1):
            res_str = f'block_{2**res}x{2**res}'
            if res == resolution_log2:
                init_rng, init_key = random.split(init_rng)
                x = FromRGBLayer(fmaps=nf(res - 1),
                                 kernel=1,
                                 activation=self.activation,
                                 param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
                                 clip_conv=self.clip_conv,
                                 dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32',
                                 rng=init_key)(x, y)

            init_rng, init_key = random.split(init_rng)
            x = DiscriminatorBlock(res=res,
                                   kernel=3,
                                   resample_kernel=self.resample_kernel,
                                   activation=self.activation,
                                   param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
                                   architecture=self.architecture_,
                                   nf=nf,
                                   clip_conv=self.clip_conv,
                                   dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32',
                                   rng=init_key)(x)

        # Layers for 4x4 resolution.
        dtype = jnp.float32
        x = x.astype(dtype)
        if self.mbstd_num_features > 0:
            x = ops.minibatch_stddev_layer(x, self.mbstd_group_size_, self.mbstd_num_features)
        init_rng, init_key = random.split(init_rng)
        x = DiscriminatorLayer(fmaps=nf(1),
                               kernel=3,
                               use_bias=True,
                               activation=self.activation,
                               layer_name='conv0',
                               param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
                               clip_conv=self.clip_conv,
                               dtype=dtype,
                               rng=init_rng)(x)

        # Switch to NCHW so that the pretrained weights still work after reshaping.
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, newshape=(-1, x.shape[1] * x.shape[2] * x.shape[3]))

        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=nf(0),
                            activation=self.activation,
                            param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
                            layer_name='fc0',
                            dtype=dtype,
                            rng=init_key)(x)

        # Output layer.
        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=1 if self.c_dim_ == 0 else mapping_fmaps,
                            param_dict=self.param_dict,
                            layer_name='output',
                            dtype=dtype,
                            rng=init_key)(x)

        if self.c_dim_ > 0:
            x = jnp.sum(x * c, axis=1, keepdims=True) / jnp.sqrt(mapping_fmaps)
        return x
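A minimal randomly-initialized forward pass through the class above (not part of the commit; NHWC layout, shapes follow the docstrings):

# Sketch: random-init Discriminator forward pass.
import jax
from stylegan2.discriminator import Discriminator

rng = jax.random.PRNGKey(0)
D = Discriminator(resolution=128, num_channels=3, c_dim=0)
x = jax.random.normal(rng, (4, 128, 128, 3))  # NHWC images
params = D.init(rng, x)      # c defaults to None since c_dim=0
logits = D.apply(params, x)  # shape [4, 1]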
stylegan2/generator.py
ADDED
@@ -0,0 +1,713 @@
1 |
+
import numpy as np
|
2 |
+
import jax
|
3 |
+
from jax import random
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import flax.linen as nn
|
6 |
+
from typing import Any, Tuple, List
|
7 |
+
import h5py
|
8 |
+
from . import ops
|
9 |
+
from stylegan2 import utils
|
10 |
+
|
11 |
+
|
12 |
+
URLS = {'afhqcat': 'https://www.dropbox.com/s/lv1r0bwvg5ta51f/stylegan2_generator_afhqcat.h5?dl=1',
|
13 |
+
'afhqdog': 'https://www.dropbox.com/s/px6ply9hv0vdwen/stylegan2_generator_afhqdog.h5?dl=1',
|
14 |
+
'afhqwild': 'https://www.dropbox.com/s/p1slbtmzhcnw9q8/stylegan2_generator_afhqwild.h5?dl=1',
|
15 |
+
'brecahad': 'https://www.dropbox.com/s/28uykhj0ku6hwg2/stylegan2_generator_brecahad.h5?dl=1',
|
16 |
+
'car': 'https://www.dropbox.com/s/67o834b6xfg9x1q/stylegan2_generator_car.h5?dl=1',
|
17 |
+
'cat': 'https://www.dropbox.com/s/cu9egc4e74e1nig/stylegan2_generator_cat.h5?dl=1',
|
18 |
+
'church': 'https://www.dropbox.com/s/kwvokfwbrhsn58m/stylegan2_generator_church.h5?dl=1',
|
19 |
+
'cifar10': 'https://www.dropbox.com/s/h1kmymjzfwwkftk/stylegan2_generator_cifar10.h5?dl=1',
|
20 |
+
'ffhq': 'https://www.dropbox.com/s/e8de1peq7p8gq9d/stylegan2_generator_ffhq.h5?dl=1',
|
21 |
+
'horse': 'https://www.dropbox.com/s/3e5bimv2d41bc13/stylegan2_generator_horse.h5?dl=1',
|
22 |
+
'metfaces': 'https://www.dropbox.com/s/75klr5k6mgm7qdy/stylegan2_generator_metfaces.h5?dl=1'}
|
23 |
+
|
24 |
+
RESOLUTION = {'metfaces': 1024,
|
25 |
+
'ffhq': 1024,
|
26 |
+
'church': 256,
|
27 |
+
'cat': 256,
|
28 |
+
'horse': 256,
|
29 |
+
'car': 512,
|
30 |
+
'brecahad': 512,
|
31 |
+
'afhqwild': 512,
|
32 |
+
'afhqdog': 512,
|
33 |
+
'afhqcat': 512,
|
34 |
+
'cifar10': 32}
|
35 |
+
|
36 |
+
C_DIM = {'metfaces': 0,
|
37 |
+
'ffhq': 0,
|
38 |
+
'church': 0,
|
39 |
+
'cat': 0,
|
40 |
+
'horse': 0,
|
41 |
+
'car': 0,
|
42 |
+
'brecahad': 0,
|
43 |
+
'afhqwild': 0,
|
44 |
+
'afhqdog': 0,
|
45 |
+
'afhqcat': 0,
|
46 |
+
'cifar10': 10}
|
47 |
+
|
48 |
+
NUM_MAPPING_LAYERS = {'metfaces': 8,
|
49 |
+
'ffhq': 8,
|
50 |
+
'church': 8,
|
51 |
+
'cat': 8,
|
52 |
+
'horse': 8,
|
53 |
+
'car': 8,
|
54 |
+
'brecahad': 8,
|
55 |
+
'afhqwild': 8,
|
56 |
+
'afhqdog': 8,
|
57 |
+
'afhqcat': 8,
|
58 |
+
'cifar10': 2}
|
59 |
+
|
60 |
+
|
61 |
+
class MappingNetwork(nn.Module):
|
62 |
+
"""
|
63 |
+
Mapping Network.
|
64 |
+
|
65 |
+
Attributes:
|
66 |
+
z_dim (int): Input latent (Z) dimensionality.
|
67 |
+
c_dim (int): Conditioning label (C) dimensionality, 0 = no label.
|
68 |
+
w_dim (int): Intermediate latent (W) dimensionality.
|
69 |
+
embed_features (int): Label embedding dimensionality, None = same as w_dim.
|
70 |
+
layer_features (int): Number of intermediate features in the mapping layers, None = same as w_dim.
|
71 |
+
num_ws (int): Number of intermediate latents to output, None = do not broadcast.
|
72 |
+
num_layers (int): Number of mapping layers.
|
73 |
+
pretrained (str): Which pretrained model to use, None for random initialization.
|
74 |
+
param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
|
75 |
+
ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
|
76 |
+
activation (str): Activation function: 'relu', 'lrelu', etc.
|
77 |
+
lr_multiplier (float): Learning rate multiplier for the mapping layers.
|
78 |
+
w_avg_beta (float): Decay for tracking the moving average of W during training, None = do not track.
|
79 |
+
dtype (str): Data type.
|
80 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
81 |
+
"""
|
82 |
+
# Dimensionality
|
83 |
+
z_dim: int=512
|
84 |
+
c_dim: int=0
|
85 |
+
w_dim: int=512
|
86 |
+
embed_features: int=None
|
87 |
+
layer_features: int=512
|
88 |
+
|
89 |
+
# Layers
|
90 |
+
num_ws: int=18
|
91 |
+
num_layers: int=8
|
92 |
+
|
93 |
+
# Pretrained
|
94 |
+
pretrained: str=None
|
95 |
+
param_dict: h5py.Group=None
|
96 |
+
ckpt_dir: str=None
|
97 |
+
|
98 |
+
# Internal details
|
99 |
+
activation: str='leaky_relu'
|
100 |
+
lr_multiplier: float=0.01
|
101 |
+
w_avg_beta: float=0.995
|
102 |
+
dtype: str='float32'
|
103 |
+
rng: Any=random.PRNGKey(0)
|
104 |
+
|
105 |
+
def setup(self):
|
106 |
+
self.embed_features_ = self.embed_features
|
107 |
+
self.c_dim_ = self.c_dim
|
108 |
+
self.layer_features_ = self.layer_features
|
109 |
+
self.num_layers_ = self.num_layers
|
110 |
+
self.param_dict_ = self.param_dict
|
111 |
+
|
112 |
+
if self.pretrained is not None and self.param_dict is None:
|
113 |
+
assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
|
114 |
+
ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
|
115 |
+
self.param_dict_ = h5py.File(ckpt_file, 'r')['mapping_network']
|
116 |
+
self.c_dim_ = C_DIM[self.pretrained]
|
117 |
+
self.num_layers_ = NUM_MAPPING_LAYERS[self.pretrained]
|
118 |
+
|
119 |
+
if self.embed_features_ is None:
|
120 |
+
self.embed_features_ = self.w_dim
|
121 |
+
if self.c_dim_ == 0:
|
122 |
+
self.embed_features_ = 0
|
123 |
+
if self.layer_features_ is None:
|
124 |
+
self.layer_features_ = self.w_dim
|
125 |
+
|
126 |
+
if self.param_dict_ is not None and 'w_avg' in self.param_dict_:
|
127 |
+
self.w_avg = self.variable('moving_stats', 'w_avg', lambda *_ : jnp.array(self.param_dict_['w_avg']), [self.w_dim])
|
128 |
+
else:
|
129 |
+
self.w_avg = self.variable('moving_stats', 'w_avg', jnp.zeros, [self.w_dim])
|
130 |
+
|
131 |
+
@nn.compact
|
132 |
+
def __call__(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, train=True):
|
133 |
+
"""
|
134 |
+
Run Mapping Network.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
z (tensor): Input noise, shape [N, z_dim].
|
138 |
+
c (tensor): Input labels, shape [N, c_dim].
|
139 |
+
truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
|
140 |
+
truncation_cutoff (int): Controls truncation. None = disable.
|
141 |
+
skip_w_avg_update (bool): If True, updates the exponential moving average of W.
|
142 |
+
train (bool): Training mode.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
(tensor): Intermediate latent W.
|
146 |
+
"""
|
147 |
+
init_rng = self.rng
|
148 |
+
# Embed, normalize, and concat inputs.
|
149 |
+
x = None
|
150 |
+
if self.z_dim > 0:
|
151 |
+
x = ops.normalize_2nd_moment(z.astype(jnp.float32))
|
152 |
+
if self.c_dim_ > 0:
|
153 |
+
# Conditioning label
|
154 |
+
y = ops.LinearLayer(in_features=self.c_dim_,
|
155 |
+
out_features=self.embed_features_,
|
156 |
+
use_bias=True,
|
157 |
+
lr_multiplier=self.lr_multiplier,
|
158 |
+
activation='linear',
|
159 |
+
param_dict=self.param_dict_,
|
160 |
+
layer_name='label_embedding',
|
161 |
+
dtype=self.dtype,
|
162 |
+
rng=init_rng)(c.astype(jnp.float32))
|
163 |
+
|
164 |
+
y = ops.normalize_2nd_moment(y)
|
165 |
+
x = jnp.concatenate((x, y), axis=1) if x is not None else y
|
166 |
+
|
167 |
+
# Main layers.
|
168 |
+
for i in range(self.num_layers_):
|
169 |
+
init_rng, init_key = random.split(init_rng)
|
170 |
+
x = ops.LinearLayer(in_features=x.shape[1],
|
171 |
+
out_features=self.layer_features_,
|
172 |
+
use_bias=True,
|
173 |
+
lr_multiplier=self.lr_multiplier,
|
174 |
+
activation=self.activation,
|
175 |
+
param_dict=self.param_dict_,
|
176 |
+
layer_name=f'fc{i}',
|
177 |
+
dtype=self.dtype,
|
178 |
+
rng=init_key)(x)
|
179 |
+
|
180 |
+
# Update moving average of W.
|
181 |
+
if self.w_avg_beta is not None and train and not skip_w_avg_update:
|
182 |
+
self.w_avg.value = self.w_avg_beta * self.w_avg.value + (1 - self.w_avg_beta) * jnp.mean(x, axis=0)
|
183 |
+
|
184 |
+
# Broadcast.
|
185 |
+
if self.num_ws is not None:
|
186 |
+
x = jnp.repeat(jnp.expand_dims(x, axis=-2), repeats=self.num_ws, axis=-2)
|
187 |
+
|
188 |
+
# Apply truncation.
|
189 |
+
if truncation_psi != 1:
|
190 |
+
assert self.w_avg_beta is not None
|
191 |
+
if self.num_ws is None or truncation_cutoff is None:
|
192 |
+
x = truncation_psi * x + (1 - truncation_psi) * self.w_avg.value
|
193 |
+
else:
|
194 |
+
x[:, :truncation_cutoff] = truncation_psi * x[:, :truncation_cutoff] + (1 - truncation_psi) * self.w_avg.value
|
195 |
+
|
196 |
+
return x
|
197 |
+
|
198 |
+
|
199 |
+
class SynthesisLayer(nn.Module):
|
200 |
+
"""
|
201 |
+
Synthesis Layer.
|
202 |
+
|
203 |
+
Attributes:
|
204 |
+
fmaps (int): Number of output channels of the modulated convolution.
|
205 |
+
kernel (int): Kernel size of the modulated convolution.
|
206 |
+
layer_idx (int): Layer index. Used to access the latent code for a specific layer.
|
207 |
+
res (int): Resolution (log2) of the current layer.
|
208 |
+
lr_multiplier (float): Learning rate multiplier.
|
209 |
+
up (bool): If True, upsample the spatial resolution.
|
210 |
+
activation (str): Activation function: 'relu', 'lrelu', etc.
|
211 |
+
use_noise (bool): If True, add spatial-specific noise.
|
212 |
+
resample_kernel (Tuple): Kernel that is used for FIR filter.
|
213 |
+
fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
|
214 |
+
param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
|
215 |
+
clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
|
216 |
+
dtype (str): Data dtype.
|
217 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
218 |
+
"""
|
219 |
+
fmaps: int
|
220 |
+
kernel: int
|
221 |
+
layer_idx: int
|
222 |
+
res: int
|
223 |
+
lr_multiplier: float=1
|
224 |
+
up: bool=False
|
225 |
+
activation: str='leaky_relu'
|
226 |
+
use_noise: bool=True
|
227 |
+
resample_kernel: Tuple=(1, 3, 3, 1)
|
228 |
+
fused_modconv: bool=False
|
229 |
+
param_dict: h5py.Group=None
|
230 |
+
clip_conv: float=None
|
231 |
+
dtype: str='float32'
|
232 |
+
rng: Any=random.PRNGKey(0)
|
233 |
+
|
234 |
+
def setup(self):
|
235 |
+
if self.param_dict is not None:
|
236 |
+
noise_const = jnp.array(self.param_dict['noise_const'], dtype=self.dtype)
|
237 |
+
else:
|
238 |
+
noise_const = random.normal(self.rng, shape=(1, 2 ** self.res, 2 ** self.res, 1), dtype=self.dtype)
|
239 |
+
self.noise_const = self.variable('noise_consts', 'noise_const', lambda *_: noise_const)
|
240 |
+
|
241 |
+
@nn.compact
|
242 |
+
def __call__(self, x, dlatents, noise_mode='random', rng=random.PRNGKey(0)):
|
243 |
+
"""
|
244 |
+
Run Synthesis Layer.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
x (tensor): Input tensor of the shape [N, H, W, C].
|
248 |
+
dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
|
249 |
+
noise_mode (str): Noise type.
|
250 |
+
- 'const': Constant noise.
|
251 |
+
- 'random': Random noise.
|
252 |
+
- 'none': No noise.
|
253 |
+
rng (jax.random.PRNGKey): PRNG for spatialwise noise.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
(tensor): Output tensor of shape [N, H', W', fmaps].
|
257 |
+
"""
|
258 |
+
assert noise_mode in ['const', 'random', 'none']
|
259 |
+
|
260 |
+
linear_rng, conv_rng = random.split(self.rng)
|
261 |
+
# Affine transformation to obtain style variable.
|
262 |
+
s = ops.LinearLayer(in_features=dlatents[:, self.layer_idx].shape[1],
|
263 |
+
out_features=x.shape[3],
|
264 |
+
use_bias=True,
|
265 |
+
bias_init=1,
|
266 |
+
lr_multiplier=self.lr_multiplier,
|
267 |
+
param_dict=self.param_dict,
|
268 |
+
layer_name='affine',
|
269 |
+
dtype=self.dtype,
|
270 |
+
rng=linear_rng)(dlatents[:, self.layer_idx])
|
271 |
+
|
272 |
+
# Noise variables.
|
273 |
+
if self.param_dict is None:
|
274 |
+
noise_strength = jnp.zeros(())
|
275 |
+
else:
|
276 |
+
noise_strength = jnp.array(self.param_dict['noise_strength'])
|
277 |
+
noise_strength = self.param(name='noise_strength', init_fn=lambda *_ : noise_strength)
|
278 |
+
|
279 |
+
# Weight and bias for convolution operation.
|
280 |
+
w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
|
281 |
+
w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'conv', conv_rng)
|
282 |
+
w = self.param(name='weight', init_fn=lambda *_ : w)
|
283 |
+
b = self.param(name='bias', init_fn=lambda *_ : b)
|
284 |
+
w = ops.equalize_lr_weight(w, self.lr_multiplier)
|
285 |
+
b = ops.equalize_lr_bias(b, self.lr_multiplier)
|
286 |
+
|
287 |
+
x = ops.modulated_conv2d_layer(x=x,
|
288 |
+
w=w,
|
289 |
+
s=s,
|
290 |
+
fmaps=self.fmaps,
|
291 |
+
kernel=self.kernel,
|
292 |
+
up=self.up,
|
293 |
+
resample_kernel=self.resample_kernel,
|
294 |
+
fused_modconv=self.fused_modconv)
|
295 |
+
|
296 |
+
if self.use_noise and noise_mode != 'none':
|
297 |
+
if noise_mode == 'const':
|
298 |
+
noise = self.noise_const.value
|
299 |
+
elif noise_mode == 'random':
|
300 |
+
noise = random.normal(rng, shape=(x.shape[0], x.shape[1], x.shape[2], 1), dtype=self.dtype)
|
301 |
+
x += noise * noise_strength.astype(self.dtype)
|
302 |
+
x += b.astype(x.dtype)
|
303 |
+
x = ops.apply_activation(x, activation=self.activation)
|
304 |
+
if self.clip_conv is not None:
|
305 |
+
x = jnp.clip(x, -self.clip_conv, self.clip_conv)
|
306 |
+
return x
|
307 |
+
|
308 |
+
|
309 |
+
class ToRGBLayer(nn.Module):
|
310 |
+
"""
|
311 |
+
To RGB Layer.
|
312 |
+
|
313 |
+
Attributes:
|
314 |
+
fmaps (int): Number of output channels of the modulated convolution.
|
315 |
+
layer_idx (int): Layer index. Used to access the latent code for a specific layer.
|
316 |
+
kernel (int): Kernel size of the modulated convolution.
|
317 |
+
lr_multiplier (float): Learning rate multiplier.
|
318 |
+
fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
|
319 |
+
param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
|
320 |
+
clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
|
321 |
+
dtype (str): Data dtype.
|
322 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
323 |
+
"""
|
324 |
+
fmaps: int
|
325 |
+
layer_idx: int
|
326 |
+
kernel: int=1
|
327 |
+
lr_multiplier: float=1
|
328 |
+
fused_modconv: bool=False
|
329 |
+
param_dict: h5py.Group=None
|
330 |
+
clip_conv: float=None
|
331 |
+
dtype: str='float32'
|
332 |
+
rng: Any=random.PRNGKey(0)
|
333 |
+
|
334 |
+
@nn.compact
|
335 |
+
def __call__(self, x, y, dlatents):
|
336 |
+
"""
|
337 |
+
Run To RGB Layer.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
x (tensor): Input tensor of shape [N, H, W, C].
|
341 |
+
y (tensor): Image of shape [N, H', W', fmaps].
|
342 |
+
dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
|
343 |
+
|
344 |
+
Returns:
|
345 |
+
(tensor): Output tensor of shape [N, H', W', fmaps].
|
346 |
+
"""
|
347 |
+
# Affine transformation to obtain style variable.
|
348 |
+
s = ops.LinearLayer(in_features=dlatents[:, self.layer_idx].shape[1],
|
349 |
+
out_features=x.shape[3],
|
350 |
+
use_bias=True,
|
351 |
+
bias_init=1,
|
352 |
+
lr_multiplier=self.lr_multiplier,
|
353 |
+
param_dict=self.param_dict,
|
354 |
+
layer_name='affine',
|
355 |
+
dtype=self.dtype,
|
356 |
+
rng=self.rng)(dlatents[:, self.layer_idx])
|
357 |
+
|
358 |
+
# Weight and bias for convolution operation.
|
359 |
+
w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
|
360 |
+
w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'conv', self.rng)
|
361 |
+
w = self.param(name='weight', init_fn=lambda *_ : w)
|
362 |
+
b = self.param(name='bias', init_fn=lambda *_ : b)
|
363 |
+
w = ops.equalize_lr_weight(w, self.lr_multiplier)
|
364 |
+
b = ops.equalize_lr_bias(b, self.lr_multiplier)
|
365 |
+
|
366 |
+
x = ops.modulated_conv2d_layer(x, w, s, fmaps=self.fmaps, kernel=self.kernel, demodulate=False, fused_modconv=self.fused_modconv)
|
367 |
+
x += b.astype(x.dtype)
|
368 |
+
x = ops.apply_activation(x, activation='linear')
|
369 |
+
if self.clip_conv is not None:
|
370 |
+
x = jnp.clip(x, -self.clip_conv, self.clip_conv)
|
371 |
+
if y is not None:
|
372 |
+
x += y.astype(jnp.float32)
|
373 |
+
return x
|
374 |
+
|
375 |
+
|
376 |
+
class SynthesisBlock(nn.Module):
|
377 |
+
"""
|
378 |
+
Synthesis Block.
|
379 |
+
|
380 |
+
Attributes:
|
381 |
+
fmaps (int): Number of output channels of the modulated convolution.
|
382 |
+
res (int): Resolution (log2) of the current block.
|
383 |
+
num_layers (int): Number of layers in the current block.
|
384 |
+
num_channels (int): Number of output color channels.
|
385 |
+
lr_multiplier (float): Learning rate multiplier.
|
386 |
+
activation (str): Activation function: 'relu', 'lrelu', etc.
|
387 |
+
use_noise (bool): If True, add spatial-specific noise.
|
388 |
+
resample_kernel (Tuple): Kernel that is used for FIR filter.
|
389 |
+
fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
|
390 |
+
param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
|
391 |
+
clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
|
392 |
+
dtype (str): Data dtype.
|
393 |
+
rng (jax.random.PRNGKey): PRNG for initialization.
|
394 |
+
"""
|
395 |
+
fmaps: int
|
396 |
+
res: int
|
397 |
+
num_layers: int=2
|
398 |
+
num_channels: int=3
|
399 |
+
lr_multiplier: float=1
|
400 |
+
activation: str='leaky_relu'
|
401 |
+
use_noise: bool=True
|
402 |
+
resample_kernel: Tuple=(1, 3, 3, 1)
|
403 |
+
fused_modconv: bool=False
|
404 |
+
param_dict: h5py.Group=None
|
405 |
+
clip_conv: float=None
|
406 |
+
dtype: str='float32'
|
407 |
+
rng: Any=random.PRNGKey(0)
|
408 |
+
|
409 |
+
@nn.compact
|
410 |
+
def __call__(self, x, y, dlatents_in, noise_mode='random', rng=random.PRNGKey(0)):
|
411 |
+
"""
|
412 |
+
Run Synthesis Block.
|
413 |
+
|
414 |
+
Args:
|
415 |
+
x (tensor): Input tensor of shape [N, H, W, C].
|
416 |
+
y (tensor): Image of shape [N, H', W', fmaps].
|
417 |
+
dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
|
418 |
+
noise_mode (str): Noise type.
|
419 |
+
- 'const': Constant noise.
|
420 |
+
- 'random': Random noise.
|
421 |
+
- 'none': No noise.
|
422 |
+
rng (jax.random.PRNGKey): PRNG for spatialwise noise.
|
423 |
+
|
424 |
+
Returns:
|
425 |
+
(tensor): Output tensor of shape [N, H', W', fmaps].
|
426 |
+
"""
|
427 |
+
x = x.astype(self.dtype)
|
428 |
+
init_rng = self.rng
|
429 |
+
for i in range(self.num_layers):
|
430 |
+
init_rng, init_key = random.split(init_rng)
|
431 |
+
x = SynthesisLayer(fmaps=self.fmaps,
|
432 |
+
kernel=3,
|
433 |
+
layer_idx=self.res * 2 - (5 - i) if self.res > 2 else 0,
|
434 |
+
res=self.res,
|
435 |
+
lr_multiplier=self.lr_multiplier,
|
436 |
+
up=i == 0 and self.res != 2,
|
437 |
+
activation=self.activation,
|
438 |
+
use_noise=self.use_noise,
|
439 |
+
resample_kernel=self.resample_kernel,
|
440 |
+
fused_modconv=self.fused_modconv,
|
441 |
+
param_dict=self.param_dict[f'layer{i}'] if self.param_dict is not None else None,
|
442 |
+
dtype=self.dtype,
|
443 |
+
rng=init_key)(x, dlatents_in, noise_mode, rng)
|
444 |
+
|
445 |
+
if self.num_layers == 2:
|
446 |
+
k = ops.setup_filter(self.resample_kernel)
|
447 |
+
y = ops.upsample2d(y, f=k, up=2)
|
448 |
+
|
449 |
+
init_rng, init_key = random.split(init_rng)
|
450 |
+
y = ToRGBLayer(fmaps=self.num_channels,
|
451 |
+
layer_idx=self.res * 2 - 3,
|
452 |
+
lr_multiplier=self.lr_multiplier,
|
453 |
+
param_dict=self.param_dict['torgb'] if self.param_dict is not None else None,
|
454 |
+
dtype=self.dtype,
|
455 |
+
rng=init_key)(x, y, dlatents_in)
|
456 |
+
return x, y
|
457 |
+
|
458 |
+
|
459 |
+
class SynthesisNetwork(nn.Module):
    """
    Synthesis Network.

    Attributes:
        resolution (int): Output resolution.
        num_channels (int): Number of output color channels.
        w_dim (int): Intermediate latent (W) dimensionality.
        fmap_base (int): Overall multiplier for the number of feature maps.
        fmap_decay (int): Log2 feature map reduction when doubling the resolution.
        fmap_min (int): Minimum number of feature maps in any layer.
        fmap_max (int): Maximum number of feature maps in any layer.
        fmap_const (int): Number of feature maps in the constant input layer. None = default.
        pretrained (str): Which pretrained model to use, None for random initialization.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        use_noise (bool): If True, add spatial-specific noise.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        fused_modconv (bool): If True, perform modulation, convolution, and demodulation as a single fused operation.
        num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Dimensionality
    resolution: int=1024
    num_channels: int=3
    w_dim: int=512

    # Capacity
    fmap_base: int=16384
    fmap_decay: int=1
    fmap_min: int=1
    fmap_max: int=512
    fmap_const: int=None

    # Pretraining
    pretrained: str=None
    param_dict: h5py.Group=None
    ckpt_dir: str=None

    # Internal details
    activation: str='leaky_relu'
    use_noise: bool=True
    resample_kernel: Tuple=(1, 3, 3, 1)
    fused_modconv: bool=False
    num_fp16_res: int=0
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        self.resolution_ = self.resolution
        self.param_dict_ = self.param_dict
        if self.pretrained is not None and self.param_dict is None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict_ = h5py.File(ckpt_file, 'r')['synthesis_network']
            self.resolution_ = RESOLUTION[self.pretrained]

    @nn.compact
    def __call__(self, dlatents_in, noise_mode='random', rng=random.PRNGKey(0)):
        """
        Run Synthesis Network.

        Args:
            dlatents_in (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
            noise_mode (str): Noise type.
                - 'const': Constant noise.
                - 'random': Random noise.
                - 'none': No noise.
            rng (jax.random.PRNGKey): PRNG for spatial noise.

        Returns:
            (tensor): Image of shape [N, H, W, num_channels].
        """
        resolution_log2 = int(np.log2(self.resolution_))
        assert self.resolution_ == 2 ** resolution_log2 and self.resolution_ >= 4

        def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
        num_layers = resolution_log2 * 2 - 2

        fmaps = self.fmap_const if self.fmap_const is not None else nf(1)

        if self.param_dict_ is None:
            const = random.normal(self.rng, (1, 4, 4, fmaps), dtype=self.dtype)
        else:
            const = jnp.array(self.param_dict_['const'], dtype=self.dtype)
        x = self.param(name='const', init_fn=lambda *_ : const)
        x = jnp.repeat(x, repeats=dlatents_in.shape[0], axis=0)

        y = None

        dlatents_in = dlatents_in.astype(jnp.float32)

        init_rng = self.rng
        for res in range(2, resolution_log2 + 1):
            init_rng, init_key = random.split(init_rng)
            x, y = SynthesisBlock(fmaps=nf(res - 1),
                                  res=res,
                                  num_layers=1 if res == 2 else 2,
                                  num_channels=self.num_channels,
                                  activation=self.activation,
                                  use_noise=self.use_noise,
                                  resample_kernel=self.resample_kernel,
                                  fused_modconv=self.fused_modconv,
                                  param_dict=self.param_dict_[f'block_{2 ** res}x{2 ** res}'] if self.param_dict_ is not None else None,
                                  clip_conv=self.clip_conv,
                                  dtype=self.dtype if res > resolution_log2 - self.num_fp16_res else 'float32',
                                  rng=init_key)(x, y, dlatents_in, noise_mode, rng)

        return y

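Aside: the capacity schedule above is easier to see with concrete numbers. nf(stage) clips fmap_base / 2**(stage * fmap_decay) into [fmap_min, fmap_max], so with the defaults every block up to 64x64 gets 512 feature maps and the count halves from there. A quick check of the formula (illustrative snippet, not part of the commit):

    import numpy as np

    fmap_base, fmap_decay, fmap_min, fmap_max = 16384, 1, 1, 512

    def nf(stage):
        return np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max)

    print([int(nf(res - 1)) for res in range(2, 11)])
    # [512, 512, 512, 512, 512, 256, 128, 64, 32]  -> blocks 4x4 ... 1024x1024
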
class Generator(nn.Module):
    """
    Generator.

    Attributes:
        resolution (int): Output resolution.
        num_channels (int): Number of output color channels.
        z_dim (int): Input latent (Z) dimensionality.
        c_dim (int): Conditioning label (C) dimensionality, 0 = no label.
        w_dim (int): Intermediate latent (W) dimensionality.
        mapping_layer_features (int): Number of intermediate features in the mapping layers, None = same as w_dim.
        mapping_embed_features (int): Label embedding dimensionality, None = same as w_dim.
        num_ws (int): Number of intermediate latents to output, None = do not broadcast.
        num_mapping_layers (int): Number of mapping layers.
        fmap_base (int): Overall multiplier for the number of feature maps.
        fmap_decay (int): Log2 feature map reduction when doubling the resolution.
        fmap_min (int): Minimum number of feature maps in any layer.
        fmap_max (int): Maximum number of feature maps in any layer.
        fmap_const (int): Number of feature maps in the constant input layer. None = default.
        pretrained (str): Which pretrained model to use, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        use_noise (bool): If True, add spatial-specific noise.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        w_avg_beta (float): Decay for tracking the moving average of W during training, None = do not track.
        mapping_lr_multiplier (float): Learning rate multiplier for the mapping network.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        fused_modconv (bool): If True, perform modulation, convolution, and demodulation as a single fused operation.
        num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Dimensionality
    resolution: int=1024
    num_channels: int=3
    z_dim: int=512
    c_dim: int=0
    w_dim: int=512
    mapping_layer_features: int=512
    mapping_embed_features: int=None

    # Layers
    num_ws: int=18
    num_mapping_layers: int=8

    # Capacity
    fmap_base: int=16384
    fmap_decay: int=1
    fmap_min: int=1
    fmap_max: int=512
    fmap_const: int=None

    # Pretraining
    pretrained: str=None
    ckpt_dir: str=None

    # Internal details
    use_noise: bool=True
    activation: str='leaky_relu'
    w_avg_beta: float=0.995
    mapping_lr_multiplier: float=0.01
    resample_kernel: Tuple=(1, 3, 3, 1)
    fused_modconv: bool=False
    num_fp16_res: int=0
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        self.resolution_ = self.resolution
        self.c_dim_ = self.c_dim
        self.num_mapping_layers_ = self.num_mapping_layers
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict = h5py.File(ckpt_file, 'r')
            self.resolution_ = RESOLUTION[self.pretrained]
            self.c_dim_ = C_DIM[self.pretrained]
            self.num_mapping_layers_ = NUM_MAPPING_LAYERS[self.pretrained]
        else:
            self.param_dict = None
        self.init_rng_mapping, self.init_rng_synthesis = random.split(self.rng)

    @nn.compact
    def __call__(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, train=True, noise_mode='random', rng=random.PRNGKey(0)):
        """
        Run Generator.

        Args:
            z (tensor): Input noise, shape [N, z_dim].
            c (tensor): Input labels, shape [N, c_dim].
            truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
            truncation_cutoff (int): Controls truncation. None = disable.
            skip_w_avg_update (bool): If True, skip updating the exponential moving average of W.
            train (bool): Training mode.
            noise_mode (str): Noise type.
                - 'const': Constant noise.
                - 'random': Random noise.
                - 'none': No noise.
            rng (jax.random.PRNGKey): PRNG for spatial noise.

        Returns:
            (tensor): Image of shape [N, H, W, num_channels].
        """
        dlatents_in = MappingNetwork(z_dim=self.z_dim,
                                     c_dim=self.c_dim_,
                                     w_dim=self.w_dim,
                                     num_ws=self.num_ws,
                                     num_layers=self.num_mapping_layers_,
                                     embed_features=self.mapping_embed_features,
                                     layer_features=self.mapping_layer_features,
                                     activation=self.activation,
                                     lr_multiplier=self.mapping_lr_multiplier,
                                     w_avg_beta=self.w_avg_beta,
                                     param_dict=self.param_dict['mapping_network'] if self.param_dict is not None else None,
                                     dtype=self.dtype,
                                     rng=self.init_rng_mapping,
                                     name='mapping_network')(z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train)

        x = SynthesisNetwork(resolution=self.resolution_,
                             num_channels=self.num_channels,
                             w_dim=self.w_dim,
                             fmap_base=self.fmap_base,
                             fmap_decay=self.fmap_decay,
                             fmap_min=self.fmap_min,
                             fmap_max=self.fmap_max,
                             fmap_const=self.fmap_const,
                             param_dict=self.param_dict['synthesis_network'] if self.param_dict is not None else None,
                             activation=self.activation,
                             use_noise=self.use_noise,
                             resample_kernel=self.resample_kernel,
                             fused_modconv=self.fused_modconv,
                             num_fp16_res=self.num_fp16_res,
                             clip_conv=self.clip_conv,
                             dtype=self.dtype,
                             rng=self.init_rng_synthesis,
                             name='synthesis_network')(dlatents_in, noise_mode, rng)

        return x
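Aside: a minimal inference sketch for the Generator above (not part of the commit). It assumes stylegan2/__init__.py exports Generator, as training.py below relies on, and an unconditional setup (c_dim=0) so labels can be None; treat the mutable-collection handling as an assumption about the module's eval-time behavior:

    import jax
    import stylegan2

    generator = stylegan2.Generator(resolution=256, c_dim=0)
    rng = jax.random.PRNGKey(0)
    z = jax.random.normal(rng, (1, generator.z_dim))
    variables = generator.init(rng, z, None)  # builds params, moving_stats, noise_consts
    images = generator.apply(variables, z, None, truncation_psi=0.5, train=False,
                             noise_mode='const', mutable=['moving_stats'])[0]
    print(images.shape)  # (1, 256, 256, 3)
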
stylegan2/ops.py
ADDED
@@ -0,0 +1,674 @@
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from jax import jit
import numpy as np
from functools import partial
from typing import Any
import h5py


#------------------------------------------------------
# Other
#------------------------------------------------------
def minibatch_stddev_layer(x, group_size=None, num_new_features=1):
    if group_size is None:
        group_size = x.shape[0]
    else:
        # Minibatch must be divisible by (or smaller than) group_size.
        group_size = min(group_size, x.shape[0])

    G = group_size
    F = num_new_features
    _, H, W, C = x.shape
    c = C // F

    # [NHWC] Cast to FP32.
    y = x.astype(jnp.float32)
    # [GnHWFc] Split minibatch N into n groups of size G, and channels C into F groups of size c.
    y = jnp.reshape(y, newshape=(G, -1, H, W, F, c))
    # [GnHWFc] Subtract mean over group.
    y -= jnp.mean(y, axis=0)
    # [nHWFc] Calc variance over group.
    y = jnp.mean(jnp.square(y), axis=0)
    # [nHWFc] Calc stddev over group.
    y = jnp.sqrt(y + 1e-8)
    # [nF] Take average over channels and pixels.
    y = jnp.mean(y, axis=(1, 2, 4))
    # [nF] Cast back to original data type.
    y = y.astype(x.dtype)
    # [n11F] Add missing dimensions.
    y = jnp.reshape(y, newshape=(-1, 1, 1, F))
    # [NHWC] Replicate over group and pixels.
    y = jnp.tile(y, (G, H, W, 1))
    return jnp.concatenate((x, y), axis=3)

|
49 |
+
# Activation
|
50 |
+
#------------------------------------------------------
|
51 |
+
def apply_activation(x, activation='linear', alpha=0.2, gain=np.sqrt(2)):
|
52 |
+
gain = jnp.array(gain, dtype=x.dtype)
|
53 |
+
if activation == 'relu':
|
54 |
+
return jax.nn.relu(x) * gain
|
55 |
+
if activation == 'leaky_relu':
|
56 |
+
return jax.nn.leaky_relu(x, negative_slope=alpha) * gain
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
#------------------------------------------------------
|
61 |
+
# Weights
|
62 |
+
#------------------------------------------------------
|
63 |
+
def get_weight(shape, lr_multiplier=1, bias=True, param_dict=None, layer_name='', key=None):
|
64 |
+
if param_dict is None:
|
65 |
+
w = random.normal(key, shape=shape, dtype=jnp.float32) / lr_multiplier
|
66 |
+
if bias: b = jnp.zeros(shape=(shape[-1],), dtype=jnp.float32)
|
67 |
+
else:
|
68 |
+
w = jnp.array(param_dict[layer_name]['weight']).astype(jnp.float32)
|
69 |
+
if bias: b = jnp.array(param_dict[layer_name]['bias']).astype(jnp.float32)
|
70 |
+
|
71 |
+
if bias: return w, b
|
72 |
+
return w
|
73 |
+
|
74 |
+
|
75 |
+
def equalize_lr_weight(w, lr_multiplier=1):
|
76 |
+
"""
|
77 |
+
Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
w (tensor): Weight parameter. Shape [kernel, kernel, fmaps_in, fmaps_out]
|
81 |
+
for convolutions and shape [in, out] for MLPs.
|
82 |
+
lr_multiplier (float): Learning rate multiplier.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
(tensor): Scaled weight parameter.
|
86 |
+
"""
|
87 |
+
in_features = np.prod(w.shape[:-1])
|
88 |
+
gain = lr_multiplier / np.sqrt(in_features)
|
89 |
+
w *= gain
|
90 |
+
return w
|
91 |
+
|
92 |
+
|
93 |
+
def equalize_lr_bias(b, lr_multiplier=1):
|
94 |
+
"""
|
95 |
+
Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
b (tensor): Bias parameter.
|
99 |
+
lr_multiplier (float): Learning rate multiplier.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
(tensor): Scaled bias parameter.
|
103 |
+
"""
|
104 |
+
gain = lr_multiplier
|
105 |
+
b *= gain
|
106 |
+
return b
|
107 |
+
|
108 |
+
|
109 |
+
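Aside: with equalized learning rate the parameters are stored at roughly unit standard deviation (get_weight divides by lr_multiplier) and the He-style constant 1/sqrt(fan_in) is applied at every forward pass, so the optimizer sees comparably scaled gradients in every layer. A numeric sketch:

    import jax.numpy as jnp
    from jax import random
    from stylegan2 import ops

    w_stored = random.normal(random.PRNGKey(0), (3, 3, 64, 128))  # kept at ~unit std
    w_runtime = ops.equalize_lr_weight(w_stored, lr_multiplier=1)
    print(float(jnp.std(w_runtime)))  # ~0.0417 = 1 / sqrt(3 * 3 * 64)
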
#------------------------------------------------------
# Normalization
#------------------------------------------------------
def normalize_2nd_moment(x, eps=1e-8):
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=1, keepdims=True) + eps)


#------------------------------------------------------
# Upsampling
#------------------------------------------------------
def setup_filter(f, normalize=True, flip_filter=False, gain=1, separable=None):
    """
    Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f (tensor): Tensor or python list of filter taps, of shape
                    [filter_height, filter_width] (non-separable) or [filter_taps] (separable).
        normalize (bool): Normalize the filter so that it retains the magnitude
                          for constant input signal (DC)? (default: True).
        flip_filter (bool): Flip the filter? (default: False).
        gain (int): Overall scaling factor for signal magnitude (default: 1).
        separable: Return a separable filter? (default: select automatically).

    Returns:
        (tensor): Output filter of shape [filter_height, filter_width] or [filter_taps].
    """
    # Validate.
    if f is None:
        f = 1
    f = jnp.array(f, dtype=jnp.float32)
    assert f.ndim in [0, 1, 2]
    assert f.size > 0
    if f.ndim == 0:
        f = f[jnp.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.size >= 8)
    if f.ndim == 1 and not separable:
        f = jnp.outer(f, f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= jnp.sum(f)
    if flip_filter:
        for i in range(f.ndim):
            f = jnp.flip(f, axis=i)
    f = f * (gain ** (f.ndim / 2))
    return f


def upfirdn2d(x, f, padding=(2, 1, 2, 1), up=1, down=1, strides=(1, 1), flip_filter=False, gain=1):

    if f is None:
        f = jnp.ones((1, 1), dtype=jnp.float32)

    B, H, W, C = x.shape
    padx0, padx1, pady0, pady1 = padding

    # upsample by inserting zeros
    x = jnp.reshape(x, newshape=(B, H, 1, W, 1, C))
    x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
    x = jnp.reshape(x, newshape=(B, H * up, W * up, C))

    # padding
    x = jnp.pad(x, pad_width=((0, 0), (max(pady0, 0), max(pady1, 0)), (max(padx0, 0), max(padx1, 0)), (0, 0)))
    x = x[:, max(-pady0, 0) : x.shape[1] - max(-pady1, 0), max(-padx0, 0) : x.shape[2] - max(-padx1, 0)]

    # setup filter
    f = f * (gain ** (f.ndim / 2))
    if not flip_filter:
        for i in range(f.ndim):
            f = jnp.flip(f, axis=i)

    # convolve filter
    f = jnp.repeat(jnp.expand_dims(f, axis=(-2, -1)), repeats=C, axis=-1)
    if f.ndim == 4:
        x = jax.lax.conv_general_dilated(x,
                                         f.astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
    else:
        x = jax.lax.conv_general_dilated(x,
                                         jnp.expand_dims(f, axis=0).astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
        x = jax.lax.conv_general_dilated(x,
                                         jnp.expand_dims(f, axis=1).astype(x.dtype),
                                         window_strides=strides or (1,) * (x.ndim - 2),
                                         padding='valid',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=C)
    x = x[:, ::down, ::down]
    return x


def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1):
    if f.ndim == 1:
        fh, fw = f.shape[0], f.shape[0]
    elif f.ndim == 2:
        fh, fw = f.shape[0], f.shape[1]
    else:
        raise ValueError('Invalid filter shape:', f.shape)
    padx0 = padding + (fw + up - 1) // 2
    padx1 = padding + (fw - up) // 2
    pady0 = padding + (fh + up - 1) // 2
    pady1 = padding + (fh - up) // 2
    return upfirdn2d(x, f=f, up=up, padding=(padx0, padx1, pady0, pady1), flip_filter=flip_filter, gain=gain * up * up)

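Aside: upsample2d implements FIR-filtered upsampling: upfirdn2d inserts up-1 zeros between pixels, pads, and low-pass filters, while gain=up*up restores the signal magnitude lost to the inserted zeros. A quick sanity check:

    import jax
    from stylegan2 import ops

    x = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 3))
    k = ops.setup_filter([1, 3, 3, 1])  # normalized 4-tap binomial -> 4x4 kernel
    y = ops.upsample2d(x, f=k, up=2)
    print(y.shape)  # (1, 16, 16, 3)
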
#------------------------------------------------------
# Linear
#------------------------------------------------------
class LinearLayer(nn.Module):
    """
    Linear Layer.

    Attributes:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        use_bias (bool): If True, use bias.
        bias_init (int): Bias init.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        layer_name (str): Layer name.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): Random seed for initialization.
    """
    in_features: int
    out_features: int
    use_bias: bool=True
    bias_init: int=0
    lr_multiplier: float=1
    activation: str='linear'
    param_dict: h5py.Group=None
    layer_name: str=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Linear Layer.

        Args:
            x (tensor): Input tensor of shape [N, in_features].

        Returns:
            (tensor): Output tensor of shape [N, out_features].
        """
        w_shape = [self.in_features, self.out_features]
        params = get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        if self.use_bias:
            w, b = params
        else:
            w = params

        w = self.param(name='weight', init_fn=lambda *_ : w)
        w = equalize_lr_weight(w, self.lr_multiplier)
        x = jnp.matmul(x, w.astype(x.dtype))

        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_ : b)
            b = equalize_lr_bias(b, self.lr_multiplier)
            x += b.astype(x.dtype)
            x += self.bias_init

        x = apply_activation(x, activation=self.activation)
        return x

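Aside: LinearLayer is a standard flax.linen module, so it is used through init/apply; the lr_multiplier=0.01 setting below mirrors how the mapping network configures its layers. A usage sketch:

    import jax
    import jax.numpy as jnp
    from stylegan2.ops import LinearLayer

    layer = LinearLayer(in_features=512, out_features=512,
                        activation='leaky_relu', lr_multiplier=0.01)
    x = jnp.ones((4, 512))
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)
    print(y.shape)  # (4, 512)
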
#------------------------------------------------------
# Convolution
#------------------------------------------------------
def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0):
    """
    Fused downsample convolution.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard ops. It supports gradients of arbitrary order.

    Args:
        x (tensor): Input tensor of the shape [N, H, W, C].
        w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
                    Grouped convolution can be performed by inChannels = x.shape[3] // numGroups.
        k (tensor): FIR filter of the shape [firH, firW] or [firN].
                    The default is `[1] * factor`, which corresponds to average pooling.
        factor (int): Downsampling factor (default: 2).
        gain (float): Scaling factor for signal magnitude (default: 1.0).
        padding (int): Number of pixels to pad or crop the output on each side (default: 0).

    Returns:
        (tensor): Output of the shape [N, H // factor, W // factor, C].
    """
    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape.
    ch, cw, _inC, _outC = w.shape
    assert cw == ch

    # Setup filter kernel.
    k = setup_filter(k, gain=gain)
    assert k.shape[0] == k.shape[1]

    # Execute.
    pad0 = (k.shape[0] - factor + cw) // 2 + padding * factor
    pad1 = (k.shape[0] - factor + cw - 1) // 2 + padding * factor
    x = upfirdn2d(x=x, f=k, padding=(pad0, pad0, pad1, pad1))

    x = jax.lax.conv_general_dilated(x,
                                     w,
                                     window_strides=(factor, factor),
                                     padding='VALID',
                                     dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
    return x


def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0):
    """
    Fused upsample convolution.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard ops. It supports gradients of arbitrary order.

    Args:
        x (tensor): Input tensor of the shape [N, H, W, C].
        w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
                    Grouped convolution can be performed by inChannels = x.shape[3] // numGroups.
        k (tensor): FIR filter of the shape [firH, firW] or [firN].
                    The default is [1] * factor, which corresponds to nearest-neighbor upsampling.
        factor (int): Integer upsampling factor (default: 2).
        gain (float): Scaling factor for signal magnitude (default: 1.0).
        padding (int): Number of pixels to pad or crop the output on each side (default: 0).

    Returns:
        (tensor): Output of the shape [N, H * factor, W * factor, C].
    """
    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape.
    ch, cw, _inC, _outC = w.shape
    inC = w.shape[2]
    outC = w.shape[3]
    assert cw == ch

    # Fast path for 1x1 convolution.
    if cw == 1 and ch == 1:
        x = jax.lax.conv_general_dilated(x,
                                         w,
                                         window_strides=(1, 1),
                                         padding='VALID',
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
        k = setup_filter(k, gain=gain * (factor ** 2))
        pad0 = (k.shape[0] + factor - cw) // 2 + padding
        pad1 = (k.shape[0] - factor) // 2 + padding
        x = upfirdn2d(x, f=k, up=factor, padding=(pad0, pad1, pad0, pad1))
        return x

    # Setup filter kernel.
    k = setup_filter(k, gain=gain * (factor ** 2))
    assert k.shape[0] == k.shape[1]

    # Determine data dimensions.
    stride = (factor, factor)
    output_shape = ((x.shape[1] - 1) * factor + ch, (x.shape[2] - 1) * factor + cw)
    num_groups = x.shape[3] // inC

    # Transpose weights.
    w = jnp.reshape(w, (ch, cw, inC, num_groups, -1))
    w = jnp.transpose(w[::-1, ::-1], (0, 1, 4, 3, 2))
    w = jnp.reshape(w, (ch, cw, -1, num_groups * inC))

    # Execute.
    x = gradient_based_conv_transpose(lhs=x,
                                      rhs=w,
                                      strides=stride,
                                      padding='VALID',
                                      output_padding=(0, 0, 0, 0),
                                      output_shape=output_shape,
                                      )

    pad0 = (k.shape[0] + factor - cw) // 2 + padding
    pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
    x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
    return x


def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
    assert not (up and down)
    kernel = w.shape[0]
    assert w.shape[1] == kernel
    assert kernel >= 1 and kernel % 2 == 1

    num_groups = x.shape[3] // w.shape[2]

    w = w.astype(x.dtype)
    if up:
        x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
    elif down:
        x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)
    else:
        padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
        x = jax.lax.conv_general_dilated(x,
                                         w,
                                         window_strides=(1, 1),
                                         padding=padding_mode,
                                         dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
                                         feature_group_count=num_groups)
    return x


def modulated_conv2d_layer(x, w, s, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, fused_modconv=False):
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1

    # Get weight.
    wshape = (kernel, kernel, x.shape[3], fmaps)
    if x.dtype.name == 'float16' and not fused_modconv and demodulate:
        w *= jnp.sqrt(1 / np.prod(wshape[:-1])) / jnp.max(jnp.abs(w), axis=(0, 1, 2))  # Pre-normalize to avoid float16 overflow.
    ww = w[jnp.newaxis]  # [BkkIO] Introduce minibatch dimension.

    # Modulate.
    if x.dtype.name == 'float16' and not fused_modconv and demodulate:
        s *= 1 / jnp.max(jnp.abs(s))  # Pre-normalize to avoid float16 overflow.
    ww *= s[:, jnp.newaxis, jnp.newaxis, :, jnp.newaxis].astype(w.dtype)  # [BkkIO] Scale input feature maps.

    # Demodulate.
    if demodulate:
        d = jax.lax.rsqrt(jnp.sum(jnp.square(ww), axis=(1, 2, 3)) + 1e-8)  # [BO] Scaling factor.
        ww *= d[:, jnp.newaxis, jnp.newaxis, jnp.newaxis, :]  # [BkkIO] Scale output feature maps.

    # Reshape/scale input.
    if fused_modconv:
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (1, -1, x.shape[2], x.shape[3]))  # Fused => reshape minibatch to convolution groups.
        x = jnp.transpose(x, axes=(0, 2, 3, 1))
        w = jnp.reshape(jnp.transpose(ww, (1, 2, 3, 0, 4)), (ww.shape[1], ww.shape[2], ww.shape[3], -1))
    else:
        x *= s[:, jnp.newaxis, jnp.newaxis].astype(x.dtype)  # [BIhw] Not fused => scale input activations.

    # 2D convolution.
    x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)

    # Reshape/scale output.
    if fused_modconv:
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (-1, fmaps, x.shape[2], x.shape[3]))  # Fused => reshape convolution groups back to minibatch.
        x = jnp.transpose(x, axes=(0, 2, 3, 1))
    elif demodulate:
        x *= d[:, jnp.newaxis, jnp.newaxis].astype(x.dtype)  # [BOhw] Not fused => scale output activations.

    return x

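Aside: modulated_conv2d_layer is the core StyleGAN2 op: the style vector s scales the convolution weight per sample (modulation) and, with demodulate=True, each output feature map is rescaled to unit expected magnitude, replacing AdaIN. A shape sketch using the default non-fused path:

    import jax
    import jax.numpy as jnp
    from stylegan2 import ops

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (2, 16, 16, 64))  # [N, H, W, C_in]
    w = jax.random.normal(key, (3, 3, 64, 128))  # [k, k, C_in, C_out]
    s = jnp.ones((2, 64))                        # one style scale per input channel
    y = ops.modulated_conv2d_layer(x, w, s, fmaps=128, kernel=3)
    print(y.shape)  # (2, 16, 16, 128)
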
def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Determines the output length of a transposed convolution given the input length.
    Function modified from Keras.

    Arguments:
        input_length: Integer.
        filter_size: Integer.
        padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple.
        output_padding: Integer, amount of padding along the output dimension. Can
            be set to `None` in which case the output length is inferred.
        stride: Integer.
        dilation: Integer.

    Returns:
        The output length (integer).
    """
    if input_length is None:
        return None

    # Get the dilated kernel size
    filter_size = filter_size + (filter_size - 1) * (dilation - 1)

    # Infer length if output padding is None, else compute the exact length
    if output_padding is None:
        if padding == 'VALID':
            length = input_length * stride + max(filter_size - stride, 0)
        elif padding == 'SAME':
            length = input_length * stride
        else:
            length = ((input_length - 1) * stride + filter_size - padding[0] - padding[1])

    else:
        if padding == 'SAME':
            pad = filter_size // 2
            total_pad = pad * 2
        elif padding == 'VALID':
            total_pad = 0
        else:
            total_pad = padding[0] + padding[1]

        length = ((input_length - 1) * stride + filter_size - total_pad + output_padding)
    return length


def _compute_adjusted_padding(input_size, output_size, kernel_size, stride, padding, dilation=1):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Computes adjusted padding for desired ConvTranspose `output_size`.
    Ported from DeepMind Haiku.
    """
    kernel_size = (kernel_size - 1) * dilation + 1
    if padding == 'VALID':
        expected_input_size = (output_size - kernel_size + stride) // stride
        if input_size != expected_input_size:
            raise ValueError(f'The expected input size with the current set of input '
                             f'parameters is {expected_input_size} which doesn\'t '
                             f'match the actual input size {input_size}.')
        padding_before = 0
    elif padding == 'SAME':
        expected_input_size = (output_size + stride - 1) // stride
        if input_size != expected_input_size:
            raise ValueError(f'The expected input size with the current set of input '
                             f'parameters is {expected_input_size} which doesn\'t '
                             f'match the actual input size {input_size}.')
        padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size)
        padding_before = padding_needed // 2
    else:
        padding_before = padding[0]  # type: ignore[assignment]

    expanded_input_size = (input_size - 1) * stride + 1
    padded_out_size = output_size + kernel_size - 1
    pad_before = kernel_size - 1 - padding_before
    pad_after = padded_out_size - expanded_input_size - pad_before
    return (pad_before, pad_after)


def _flip_axes(x, axes):
    """
    Taken from: https://github.com/google/jax/blob/master/jax/_src/lax/lax.py

    Flip ndarray 'x' along each axis specified in axes tuple.
    """
    for axis in axes:
        x = jnp.flip(x, axis)
    return x


def gradient_based_conv_transpose(lhs,
                                  rhs,
                                  strides,
                                  padding,
                                  output_padding,
                                  output_shape=None,
                                  dilation=None,
                                  dimension_numbers=None,
                                  transpose_kernel=True,
                                  feature_group_count=1,
                                  precision=None):
    """
    Taken from: https://github.com/google/jax/pull/5772/commits

    Convenience wrapper for calculating the N-d transposed convolution.
    Much like `conv_transpose`, this function calculates transposed convolutions
    via fractionally strided convolution rather than calculating the gradient
    (transpose) of a forward convolution. However, the latter is more common
    among deep learning frameworks, such as TensorFlow, PyTorch, and Keras.
    This function provides the same set of APIs to help reproduce results in these frameworks.

    Args:
        lhs: a rank `n+2` dimensional input array.
        rhs: a rank `n+2` dimensional array of kernel weights.
        strides: sequence of `n` integers, amounts to strides of the corresponding forward convolution.
        padding: `"SAME"`, `"VALID"`, or a sequence of `n` integer 2-tuples that controls
            the before-and-after padding for each `n` spatial dimension of
            the corresponding forward convolution.
        output_padding: A sequence of integers specifying the amount of padding along
            each spatial dimension of the output tensor, used to disambiguate the output shape of
            transposed convolutions when the stride is larger than 1.
            (see a detailed description at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
            The amount of output padding along a given dimension must
            be lower than the stride along that same dimension.
            If set to `None` (default), the output shape is inferred.
            If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
        output_shape: Output shape of the spatial dimensions of a transpose
            convolution. Can be `None` or an iterable of `n` integers. If a `None` value is given (default),
            the shape is automatically calculated.
            Similar to `output_padding`, `output_shape` is also for disambiguating the output shape
            when stride > 1 (see also
            https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose)
            If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
        dilation: `None`, or a sequence of `n` integers, giving the
            dilation factor to apply in each spatial dimension of `rhs`. Dilated convolution
            is also known as atrous convolution.
        dimension_numbers: tuple of dimension descriptors as in lax.conv_general_dilated. Defaults to tensorflow convention.
        transpose_kernel: if `True` flips spatial axes and swaps the input/output
            channel axes of the kernel. This makes the output of this function identical
            to the gradient-derived functions like keras.layers.Conv2DTranspose and
            torch.nn.ConvTranspose2d applied to the same kernel.
            Although for typical use in neural nets this is unnecessary
            and makes input/output channel specification confusing, you need to set this to `True`
            in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and PyTorch.
        precision: Optional. Either ``None``, which means the default precision for
            the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
            ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
            ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

    Returns:
        Transposed N-d convolution.
    """
    assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
    ndims = len(lhs.shape)
    one = (1,) * (ndims - 2)
    # Set dimensional layout defaults if not specified.
    if dimension_numbers is None:
        if ndims == 2:
            dimension_numbers = ('NC', 'IO', 'NC')
        elif ndims == 3:
            dimension_numbers = ('NHC', 'HIO', 'NHC')
        elif ndims == 4:
            dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
        elif ndims == 5:
            dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
        else:
            raise ValueError('No 4+ dimensional dimension_number defaults.')
    dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
    k_shape = np.take(rhs.shape, dn.rhs_spec)
    k_sdims = k_shape[2:]  # type: ignore[index]
    i_shape = np.take(lhs.shape, dn.lhs_spec)
    i_sdims = i_shape[2:]  # type: ignore[index]

    # Calculate correct output shape given padding and strides.
    if dilation is None:
        dilation = (1,) * (rhs.ndim - 2)

    if output_padding is None:
        output_padding = [None] * (rhs.ndim - 2)  # type: ignore[list-item]

    if isinstance(padding, str):
        if padding in {'SAME', 'VALID'}:
            padding = [padding] * (rhs.ndim - 2)  # type: ignore[list-item]
        else:
            raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.")

    inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation))

    if output_shape is None:
        output_shape = inferred_output_shape  # type: ignore[assignment]
    else:
        if not output_shape == inferred_output_shape:
            raise ValueError(f'`output_padding` and `output_shape` are not compatible. '
                             f'Inferred output shape from `output_padding`: {inferred_output_shape}, '
                             f'but got `output_shape` {output_shape}')

    pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation))

    if transpose_kernel:
        # flip spatial dims and swap input / output channel axes
        rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
        rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
    return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, feature_group_count, precision=precision)

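Aside: for 'VALID' padding and output_padding=None, _deconv_output_length infers the framework-convention transposed-convolution size in*stride + max(kernel - stride, 0). A quick check against that formula:

    import jax.numpy as jnp
    from stylegan2 import ops

    x = jnp.ones((1, 8, 8, 4))
    w = jnp.ones((3, 3, 4, 4))  # HWIO; transpose_kernel=True swaps I/O internally
    y = ops.gradient_based_conv_transpose(x, w, strides=(2, 2), padding='VALID',
                                          output_padding=None)
    print(y.shape)  # (1, 17, 17, 4): (8 - 1) * 2 + 3
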
stylegan2/utils.py
ADDED
@@ -0,0 +1,37 @@
from tqdm import tqdm
import requests
import os
import tempfile


def download(ckpt_dir, url):
    name = url[url.rfind('/') + 1 : url.rfind('?')]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

        # first create a temp file, in case the download fails
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            print('An error occurred while downloading, please try again.')
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
        else:
            # if the download was successful, rename the temp file
            os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file
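Aside: download derives the cache filename from the URL segment between the last '/' and the '?', and stores it under <ckpt_dir>/flaxmodels, so the pretrained-weight URLs are expected to carry a query string. A hypothetical call (the URL is a placeholder, not a real asset; running it would attempt an actual download):

    from stylegan2 import utils

    path = utils.download(None, 'https://example.com/stylegan2_ffhq.h5?dl=1')
    print(path)  # e.g. /tmp/flaxmodels/stylegan2_ffhq.h5
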
training.py
ADDED
@@ -0,0 +1,382 @@
import jax
import jax.numpy as jnp
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
import optax
import numpy as np
import functools
import wandb
import time

import stylegan2
import data_pipeline
import checkpoint
import training_utils
import training_steps
from fid import FID

import logging

logger = logging.getLogger(__name__)


def tree_shape(item):
    return jax.tree_map(lambda c: c.shape, item)


def train_and_evaluate(config):
    num_devices = jax.device_count()              # total number of devices across all hosts
    num_local_devices = jax.local_device_count()  # number of devices on this host
    num_workers = jax.process_count()

    # --------------------------------------
    # Data
    # --------------------------------------
    ds_train, dataset_info = data_pipeline.get_data(data_dir=config.data_dir,
                                                    img_size=config.resolution,
                                                    img_channels=config.img_channels,
                                                    num_classes=config.c_dim,
                                                    num_local_devices=num_local_devices,
                                                    batch_size=config.batch_size)

    # --------------------------------------
    # Seeding and Precision
    # --------------------------------------
    rng = jax.random.PRNGKey(config.random_seed)

    if config.mixed_precision:
        dtype = jnp.float16
    elif config.bf16:
        dtype = jnp.bfloat16
    else:
        dtype = jnp.float32
    logger.info(f'Running on dtype {dtype}')

    platform = jax.local_devices()[0].platform
    if config.mixed_precision and platform == 'gpu':
        dynamic_scale_G_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_G_reg = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_reg = dynamic_scale_lib.DynamicScale()
        clip_conv = 256
        num_fp16_res = 4
    else:
        dynamic_scale_G_main = None
        dynamic_scale_D_main = None
        dynamic_scale_G_reg = None
        dynamic_scale_D_reg = None
        clip_conv = None
        num_fp16_res = 0

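Aside: the DynamicScale objects created above wrap value_and_grad with automatic float16 loss scaling; the training steps use them to obtain unscaled gradients plus a finiteness flag that gates the optimizer update. A minimal sketch of that pattern (illustrative; not the repo's training_steps code, and the dynamic_scale import path varies across flax versions):

    import jax.numpy as jnp
    from flax.training import dynamic_scale as dynamic_scale_lib

    def loss_fn(params):
        return jnp.mean(params ** 2)

    dyn_scale = dynamic_scale_lib.DynamicScale()
    params = jnp.ones((4,))
    dyn_scale, is_finite, loss, grads = dyn_scale.value_and_grad(loss_fn)(params)
    # grads are already divided by the scale; skip the update step if not is_finite.
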
    # --------------------------------------
    # Initialize Models
    # --------------------------------------
    logger.info('Initialize models...')

    rng, init_rng = jax.random.split(rng)

    # Generator initialization for training
    start_mn = time.time()
    logger.info("Creating MappingNetwork...")
    mapping_net = stylegan2.MappingNetwork(z_dim=config.z_dim,
                                           c_dim=config.c_dim,
                                           w_dim=config.w_dim,
                                           num_ws=int(np.log2(config.resolution)) * 2 - 3,
                                           num_layers=8,
                                           dtype=dtype)

    mapping_net_vars = mapping_net.init(init_rng,
                                        jnp.ones((1, config.z_dim)),
                                        jnp.ones((1, config.c_dim)))

    mapping_net_params, moving_stats = mapping_net_vars['params'], mapping_net_vars['moving_stats']

    logger.info(f"MappingNetwork took {time.time() - start_mn:.2f}s")

    logger.info("Creating SynthesisNetwork...")
    start_sn = time.time()
    synthesis_net = stylegan2.SynthesisNetwork(resolution=config.resolution,
                                               num_channels=config.img_channels,
                                               w_dim=config.w_dim,
                                               fmap_base=config.fmap_base,
                                               num_fp16_res=num_fp16_res,
                                               clip_conv=clip_conv,
                                               dtype=dtype)

    synthesis_net_vars = synthesis_net.init(init_rng,
                                            jnp.ones((1, mapping_net.num_ws, config.w_dim)))
    synthesis_net_params, noise_consts = synthesis_net_vars['params'], synthesis_net_vars['noise_consts']

    logger.info(f"SynthesisNetwork took {time.time() - start_sn:.2f}s")

    params_G = frozen_dict.FrozenDict(
        {'mapping': mapping_net_params,
         'synthesis': synthesis_net_params}
    )

    # Discriminator initialization for training
    logger.info("Creating Discriminator...")
    start_d = time.time()
    discriminator = stylegan2.Discriminator(resolution=config.resolution,
                                            num_channels=config.img_channels,
                                            c_dim=config.c_dim,
                                            mbstd_group_size=config.mbstd_group_size,
                                            num_fp16_res=num_fp16_res,
                                            clip_conv=clip_conv,
                                            dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    params_D = discriminator.init(init_rng,
                                  jnp.ones((1, config.resolution, config.resolution, config.img_channels)),
                                  jnp.ones((1, config.c_dim)))
    logger.info(f"Discriminator took {time.time() - start_d:.2f}s")

    # Exponential moving average (EMA) Generator initialization
    logger.info("Creating Generator EMA...")
    start_g = time.time()
    generator_ema = stylegan2.Generator(resolution=config.resolution,
                                        num_channels=config.img_channels,
                                        z_dim=config.z_dim,
                                        c_dim=config.c_dim,
                                        w_dim=config.w_dim,
                                        num_ws=int(np.log2(config.resolution)) * 2 - 3,
                                        num_mapping_layers=8,
                                        fmap_base=config.fmap_base,
                                        num_fp16_res=num_fp16_res,
                                        clip_conv=clip_conv,
                                        dtype=dtype)

    params_ema_G = generator_ema.init(init_rng,
                                      jnp.ones((1, config.z_dim)),
                                      jnp.ones((1, config.c_dim)))
    logger.info(f"Generator EMA took {time.time() - start_g:.2f}s")

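Aside: training_utils.update_generator_ema is not part of this commit view; calling it with ema_beta=0, as done just below, copies the training parameters into the EMA generator outright. A generic sketch of that style of update (an assumption about its behavior, not the repo's code):

    import jax

    def ema_update(params_train, params_ema, ema_beta=0.999):
        # ema_beta=0 reduces to a plain copy of the training parameters.
        return jax.tree_map(lambda p, e: ema_beta * e + (1.0 - ema_beta) * p,
                            params_train, params_ema)
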
    # --------------------------------------
    # Initialize States and Optimizers
    # --------------------------------------
    logger.info('Initialize states...')
    tx_G = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
    tx_D = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)

    state_G = training_utils.TrainStateG.create(apply_fn=None,
                                                apply_mapping=mapping_net.apply,
                                                apply_synthesis=synthesis_net.apply,
                                                params=params_G,
                                                moving_stats=moving_stats,
                                                noise_consts=noise_consts,
                                                tx=tx_G,
                                                dynamic_scale_main=dynamic_scale_G_main,
                                                dynamic_scale_reg=dynamic_scale_G_reg,
                                                epoch=0)

    state_D = training_utils.TrainStateD.create(apply_fn=discriminator.apply,
                                                params=params_D,
                                                tx=tx_D,
                                                dynamic_scale_main=dynamic_scale_D_main,
                                                dynamic_scale_reg=dynamic_scale_D_reg,
                                                epoch=0)

    # Copy over the parameters from the training generator to the EMA generator
    params_ema_G = training_utils.update_generator_ema(state_G, params_ema_G, config, ema_beta=0)

    # Running mean of path length for path length regularization
    pl_mean = jnp.zeros((), dtype=dtype)

    step = 0
    epoch_offset = 0
    best_fid_score = np.inf
    ckpt_path = None

    if config.resume_run_id is not None:
        # Resume training from an existing checkpoint
        ckpt_path = checkpoint.get_latest_checkpoint(config.ckpt_dir)
        logger.info(f'Resume training from checkpoint: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        step = ckpt['step']
        epoch_offset = ckpt['epoch']
        best_fid_score = ckpt['fid_score']
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']
        config = ckpt['config']
    elif config.load_from_pkl is not None:
        # Load checkpoint and start a new run
        ckpt_path = config.load_from_pkl
        logger.info(f'Load model state from: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']

    # Replicate states across devices
    pl_mean = flax.jax_utils.replicate(pl_mean)
    state_G = flax.jax_utils.replicate(state_G)
    state_D = flax.jax_utils.replicate(state_D)

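Aside: flax.jax_utils.replicate stacks one copy of the pytree per local device, producing the leading device axis that pmap expects; unreplicate takes a single copy back out (checkpoint.py relies on it before pickling). A tiny demonstration:

    import flax
    import jax.numpy as jnp

    x = jnp.zeros(())
    x_repl = flax.jax_utils.replicate(x)
    print(x_repl.shape)                        # (jax.local_device_count(),)
    print(flax.jax_utils.unreplicate(x_repl))  # 0.0, a single copy again
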
    # --------------------------------------
    # Precompile train and eval steps
    # --------------------------------------
    logger.info('Precompile training steps...')
    p_main_step_G = jax.pmap(training_steps.main_step_G, axis_name='batch')
    p_regul_step_G = jax.pmap(functools.partial(training_steps.regul_step_G, config=config), axis_name='batch')

    p_main_step_D = jax.pmap(training_steps.main_step_D, axis_name='batch')
    p_regul_step_D = jax.pmap(functools.partial(training_steps.regul_step_D, config=config), axis_name='batch')

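Aside: each step function is pmapped once here and then reused every iteration, so the compilation cost is paid up front; axis_name='batch' is what allows cross-device collectives such as jax.lax.pmean inside the steps. A minimal sketch of the pattern (illustrative, not the repo's step code):

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.pmap, axis_name='batch')
    def device_mean(x):
        # Average a per-device value across all devices, as a pmapped
        # training step would do for gradients and losses.
        return jax.lax.pmean(x, axis_name='batch')

    n = jax.local_device_count()
    print(device_mean(jnp.arange(float(n))))  # every entry equals (n - 1) / 2
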
# --------------------------------------
|
229 |
+
# Training
|
230 |
+
# --------------------------------------
|
231 |
+
logger.info('Start training...')
|
232 |
+
fid_metric = FID(generator_ema, ds_train, config)
|
233 |
+
|
234 |
+
# Dict to collect training statistics / losses
|
235 |
+
metrics = {}
|
236 |
+
num_imgs_processed = 0
|
237 |
+
num_steps_per_epoch = dataset_info['num_examples'] // (config.batch_size * num_devices)
|
238 |
+
effective_batch_size = config.batch_size * num_devices
|
239 |
+
if config.wandb and jax.process_index() == 0:
|
240 |
+
# do some more logging
|
241 |
+
wandb.config.effective_batch_size = effective_batch_size
|
242 |
+
wandb.config.num_steps_per_epoch = num_steps_per_epoch
|
243 |
+
wandb.config.num_workers = num_workers
|
244 |
+
wandb.config.device_count = num_devices
|
245 |
+
wandb.config.num_examples = dataset_info['num_examples']
|
246 |
+
wandb.config.vm_name = training_utils.get_vm_name()
|
247 |
+
|
248 |
+
for epoch in range(epoch_offset, config.num_epochs):
|
249 |
+
if config.wandb and jax.process_index() == 0:
|
250 |
+
wandb.log({'training/epochs': epoch}, step=step)
|
251 |
+
|
252 |
+
for batch in data_pipeline.prefetch(ds_train, config.num_prefetch):
|
253 |
+
assert batch['image'].shape[1] == config.batch_size, f"Mismatched batch (batch size: {config.batch_size}, this batch: {batch['image'].shape[1]})"
|
254 |
+
|
255 |
+
# pbar.update(num_devices * config.batch_size)
|
256 |
+
iteration_start_time = time.time()
|
257 |
+
|
258 |
+
if config.c_dim == 0:
|
259 |
+
# No labels in the dataset
|
260 |
+
batch['label'] = None
|
261 |
+
|
262 |
+
# Create two latent noise vectors and combine them for the style mixing regularization
|
263 |
+
rng, key = jax.random.split(rng)
|
264 |
+
z_latent1 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
|
265 |
+
rng, key = jax.random.split(rng)
|
266 |
+
z_latent2 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
|
267 |
+
|
268 |
+
# Split PRNGs across devices
|
269 |
+
rkey = jax.random.split(key, num=num_local_devices)
|
270 |
+
mixing_prob = flax.jax_utils.replicate(config.mixing_prob)
|
271 |
+
|
272 |
+
# --------------------------------------
|
273 |
+
# Update Discriminator
|
274 |
+
# --------------------------------------
|
275 |
+
time_d_start = time.time()
|
276 |
+
state_D, metrics = p_main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
|
277 |
+
time_d_end = time.time()
|
278 |
+
if step % config.D_reg_interval == 0:
|
279 |
+
state_D, metrics = p_regul_step_D(state_D, batch, metrics)
|
280 |
+
|

            # --------------------------------------
            # Update Generator
            # --------------------------------------
            time_g_start = time.time()
            state_G, metrics = p_main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            if step % config.G_reg_interval == 0:
                H, W = batch['image'].shape[-3], batch['image'].shape[-2]
                rng, key = jax.random.split(rng)
                pl_noise = jax.random.normal(key, batch['image'].shape, dtype=dtype) / np.sqrt(H * W)
                state_G, metrics, pl_mean = p_regul_step_G(state_G, batch, z_latent1, pl_noise, pl_mean, metrics,
                                                           rng=rkey)

            params_ema_G = training_utils.update_generator_ema(flax.jax_utils.unreplicate(state_G),
                                                               params_ema_G,
                                                               config)
            time_g_end = time.time()
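
            # The path length noise is scaled by 1/sqrt(H*W) so the magnitude of
            # the projection used in regul_step_G is independent of the image
            # resolution. The EMA copy of the generator is updated on the host
            # from the unreplicated parameters after every step.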

            # --------------------------------------
            # Logging and Checkpointing
            # --------------------------------------
            if step % config.save_every == 0 and config.disable_fid:
                # If FID evaluation is disabled, a checkpoint will be saved every 'save_every' steps.
                if jax.process_index() == 0:
                    logger.info('Saving checkpoint...')
                    checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step,
                                               epoch)

            num_imgs_processed += num_devices * config.batch_size
            if step % config.eval_fid_every == 0 and not config.disable_fid:
                # If FID evaluation is enabled, only save a checkpoint if the FID score improved.
                if jax.process_index() == 0:
                    logger.info('Computing FID...')
                    fid_score = fid_metric.compute_fid(params_ema_G).item()
                    if config.wandb:
                        wandb.log({'training/gen/fid': fid_score}, step=step)
                    logger.info(f'Computed FID: {fid_score:.2f}')
                    if fid_score < best_fid_score:
                        best_fid_score = fid_score
                        logger.info(f'New best FID score ({best_fid_score:.3f}). Saving checkpoint...')
                        ts = time.time()
                        checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=fid_score)
                        te = time.time()
                        logger.info(f'... successfully saved checkpoint in {(te - ts) / 60:.1f}min')

            sec_per_kimg = (time.time() - iteration_start_time) / (num_devices * config.batch_size / 1000.0)
            time_taken_g = time_g_end - time_g_start
            time_taken_d = time_d_end - time_d_start
            time_taken_per_step = time.time() - iteration_start_time
            g_loss = jnp.mean(metrics['G_loss']).item()
            d_loss = jnp.mean(metrics['D_loss']).item()
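
            # The wandb calls below use commit=False so all metrics of one
            # training step are buffered and flushed together by the final
            # wandb.log(..., step=step) call, which commits the step.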

            if config.wandb and jax.process_index() == 0:
                # wandb logging - happens every step
                wandb.log({'training/gen/loss': g_loss}, step=step, commit=False)
                wandb.log({'training/dis/loss': d_loss}, step=step, commit=False)
                wandb.log({'training/dis/fake_logits': jnp.mean(metrics['fake_logits']).item()}, step=step, commit=False)
                wandb.log({'training/dis/real_logits': jnp.mean(metrics['real_logits']).item()}, step=step, commit=False)
                wandb.log({'training/time_taken_g': time_taken_g, 'training/time_taken_d': time_taken_d}, step=step, commit=False)
                wandb.log({'training/time_taken_per_step': time_taken_per_step}, step=step, commit=False)
                wandb.log({'training/num_imgs_trained': num_imgs_processed}, step=step, commit=False)
                wandb.log({'training/sec_per_kimg': sec_per_kimg}, step=step)

            if step % config.log_every == 0:
                # console logging - happens every log_every steps
                logger.info(f'Total steps: {step:>6,} - epoch {epoch:>3,}/{config.num_epochs} @ {step % num_steps_per_epoch:>6,}/{num_steps_per_epoch:,} - G loss: {g_loss:.5f} - D loss: {d_loss:.5f} - sec/kimg: {sec_per_kimg:.2f}s - time per step: {time_taken_per_step:.3f}s')

            if step % config.generate_samples_every == 0 and config.wandb and jax.process_index() == 0:
                # Generate training images
                train_snapshot = training_utils.get_training_snapshot(
                    image_real=flax.jax_utils.unreplicate(batch['image']),
                    image_gen=flax.jax_utils.unreplicate(metrics['image_gen']),
                    max_num=10
                )
                wandb.log({'training/snapshot': wandb.Image(train_snapshot)}, commit=False, step=step)

                # Generate evaluation images
                labels = None if config.c_dim == 0 else batch['label'][0]
                image_gen_eval = training_steps.eval_step_G(
                    generator_ema,
                    params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=1
                )
                image_gen_eval_trunc = training_steps.eval_step_G(
                    generator_ema,
                    params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=0.5
                )
                eval_snapshot = training_utils.get_eval_snapshot(image=image_gen_eval, max_num=10)
                eval_snapshot_trunc = training_utils.get_eval_snapshot(image=image_gen_eval_trunc, max_num=10)
                wandb.log({'eval/snapshot': wandb.Image(eval_snapshot)}, commit=False, step=step)
                wandb.log({'eval/snapshot_trunc': wandb.Image(eval_snapshot_trunc)}, step=step)
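
            # For the eval snapshots above, truncation_psi trades diversity for
            # fidelity: psi=1 samples the full latent distribution, while psi=0.5
            # pulls w towards the mean latent w_avg, giving cleaner but less
            # varied samples.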

            step += 1

        # Sync moving stats across devices
        state_G = training_utils.sync_moving_stats(state_G)

        # Sync moving average of path length mean (Generator regularization)
        pl_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name='batch'), axis_name='batch')(pl_mean)
training_steps.py
ADDED
@@ -0,0 +1,219 @@
import jax
import jax.numpy as jnp
import functools


def main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rng):

    def loss_fn(params):
        w_latent1, new_state_G = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
                                                       z_latent1,
                                                       batch['label'],
                                                       mutable=['moving_stats'])
        w_latent2 = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
                                          z_latent2,
                                          batch['label'],
                                          skip_w_avg_update=True)

        # Style mixing: with probability mixing_prob, replace the styles of all
        # layers past a randomly chosen cutoff with those of the second latent.
        cutoff_rng, layer_select_rng, synth_rng = jax.random.split(rng, num=3)
        num_layers = w_latent1.shape[1]
        layer_idx = jnp.arange(num_layers)[jnp.newaxis, :, jnp.newaxis]
        mixing_cutoff = jax.lax.cond(jax.random.uniform(cutoff_rng, (), minval=0.0, maxval=1.0) < mixing_prob,
                                     lambda _: jax.random.randint(layer_select_rng, (), 1, num_layers, dtype=jnp.int32),
                                     lambda _: num_layers,
                                     operand=None)
        mixing_cond = jnp.broadcast_to(layer_idx < mixing_cutoff, w_latent1.shape)
        w_latent = jnp.where(mixing_cond, w_latent1, w_latent2)

        image_gen = state_G.apply_synthesis({'params': params['synthesis'], 'noise_consts': state_G.noise_consts},
                                            w_latent,
                                            rng=synth_rng)

        fake_logits = state_D.apply_fn(state_D.params, image_gen, batch['label'])
        loss = jnp.mean(jax.nn.softplus(-fake_logits))
        return loss, (fake_logits, image_gen, new_state_G)

    dynamic_scale = state_G.dynamic_scale_main

    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
        dynamic_scale, is_fin, aux, grads = grad_fn(state_G.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_G.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    _, image_gen, new_state = aux[1]
    metrics['G_loss'] = loss
    metrics['image_gen'] = image_gen

    new_state_G = state_G.apply_gradients(grads=grads, moving_stats=new_state['moving_stats'])

    if dynamic_scale:
        new_state_G = new_state_G.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_G.opt_state,
                                                                      state_G.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_G.params,
                                                                   state_G.params))
        metrics['G_scale'] = dynamic_scale.scale

    return new_state_G, metrics
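
# main_step_G implements the non-saturating GAN loss for the generator,
# loss = E[softplus(-D(G(z)))], and is meant to be replicated across devices
# with jax.pmap (axis_name='batch'), e.g.:
#   p_main_step_G = jax.pmap(main_step_G, axis_name='batch')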

def regul_step_G(state_G, batch, z_latent, pl_noise, pl_mean, metrics, config, rng):

    def loss_fn(params):
        w_latent, new_state_G = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
                                                      z_latent,
                                                      batch['label'],
                                                      mutable=['moving_stats'])

        pl_grads = jax.grad(lambda *args: jnp.sum(state_G.apply_synthesis(*args) * pl_noise), argnums=1)({'params': params['synthesis'],
                                                                                                          'noise_consts': state_G.noise_consts},
                                                                                                         w_latent,
                                                                                                         'random',
                                                                                                         rng)
        pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=2), axis=1))
        pl_mean_new = pl_mean + config.pl_decay * (jnp.mean(pl_lengths) - pl_mean)
        pl_penalty = jnp.square(pl_lengths - pl_mean_new) * config.pl_weight
        loss = jnp.mean(pl_penalty) * config.G_reg_interval

        return loss, pl_mean_new

    dynamic_scale = state_G.dynamic_scale_reg

    if dynamic_scale:
        # axis_name is passed so gradients are also averaged across devices in
        # the mixed-precision branch, matching the explicit pmean below.
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
        dynamic_scale, is_fin, aux, grads = grad_fn(state_G.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_G.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    pl_mean_new = aux[1]

    metrics['G_regul_loss'] = loss
    new_state_G = state_G.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state_G = new_state_G.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_G.opt_state,
                                                                      state_G.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_G.params,
                                                                   state_G.params))
        metrics['G_regul_scale'] = dynamic_scale.scale

    return new_state_G, metrics, pl_mean_new
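
# Path length regularization (StyleGAN2): with the noise image scaled by
# 1/sqrt(H*W), pl_lengths approximates ||J_w^T y|| for the synthesis Jacobian
# J_w, and the penalty pl_weight * (pl_lengths - pl_mean)^2 encourages all w
# to have a similar, stationary output sensitivity. pl_mean tracks an
# exponential moving average of pl_lengths with decay rate pl_decay.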

def main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rng):

    def loss_fn(params):
        w_latent1 = state_G.apply_mapping({'params': state_G.params['mapping'], 'moving_stats': state_G.moving_stats},
                                          z_latent1,
                                          batch['label'],
                                          train=False)

        w_latent2 = state_G.apply_mapping({'params': state_G.params['mapping'], 'moving_stats': state_G.moving_stats},
                                          z_latent2,
                                          batch['label'],
                                          train=False)

        # Style mixing (same scheme as in main_step_G)
        cutoff_rng, layer_select_rng, synth_rng = jax.random.split(rng, num=3)
        num_layers = w_latent1.shape[1]
        layer_idx = jnp.arange(num_layers)[jnp.newaxis, :, jnp.newaxis]
        mixing_cutoff = jax.lax.cond(jax.random.uniform(cutoff_rng, (), minval=0.0, maxval=1.0) < mixing_prob,
                                     lambda _: jax.random.randint(layer_select_rng, (), 1, num_layers, dtype=jnp.int32),
                                     lambda _: num_layers,
                                     operand=None)
        mixing_cond = jnp.broadcast_to(layer_idx < mixing_cutoff, w_latent1.shape)
        w_latent = jnp.where(mixing_cond, w_latent1, w_latent2)

        image_gen = state_G.apply_synthesis({'params': state_G.params['synthesis'], 'noise_consts': state_G.noise_consts},
                                            w_latent,
                                            rng=synth_rng)

        fake_logits = state_D.apply_fn(params, image_gen, batch['label'])
        real_logits = state_D.apply_fn(params, batch['image'], batch['label'])

        loss_fake = jax.nn.softplus(fake_logits)
        loss_real = jax.nn.softplus(-real_logits)
        loss = jnp.mean(loss_fake + loss_real)

        return loss, (fake_logits, real_logits)

    dynamic_scale = state_D.dynamic_scale_main

    if dynamic_scale:
        # axis_name is passed so gradients are also averaged across devices in
        # the mixed-precision branch, matching the explicit pmean below.
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
        dynamic_scale, is_fin, aux, grads = grad_fn(state_D.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_D.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    fake_logits, real_logits = aux[1]
    metrics['D_loss'] = loss
    metrics['fake_logits'] = jnp.mean(fake_logits)
    metrics['real_logits'] = jnp.mean(real_logits)

    new_state_D = state_D.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state_D = new_state_D.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_D.opt_state,
                                                                      state_D.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_D.params,
                                                                   state_D.params))
        metrics['D_scale'] = dynamic_scale.scale

    return new_state_D, metrics
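
# The discriminator minimizes the logistic GAN loss,
#   E[softplus(D(G(z)))] + E[softplus(-D(x))],
# i.e. the negative log-likelihood of classifying reals as real and fakes as
# fake. The generator path above uses frozen parameters (state_G.params) and
# train=False, so only the discriminator is differentiated.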

def regul_step_D(state_D, batch, metrics, config):

    def loss_fn(params):
        r1_grads = jax.grad(lambda *args: jnp.sum(state_D.apply_fn(*args)), argnums=1)(params, batch['image'], batch['label'])
        r1_penalty = jnp.sum(jnp.square(r1_grads), axis=(1, 2, 3)) * (config.r1_gamma / 2) * config.D_reg_interval
        loss = jnp.mean(r1_penalty)
        return loss, None

    dynamic_scale = state_D.dynamic_scale_reg

    if dynamic_scale:
        # axis_name is passed so gradients are also averaged across devices in
        # the mixed-precision branch, matching the explicit pmean below.
        grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
        dynamic_scale, is_fin, aux, grads = grad_fn(state_D.params)
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state_D.params)
        grads = jax.lax.pmean(grads, axis_name='batch')

    loss = aux[0]
    metrics['D_regul_loss'] = loss

    new_state_D = state_D.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state_D = new_state_D.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                      new_state_D.opt_state,
                                                                      state_D.opt_state),
                                          params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
                                                                   new_state_D.params,
                                                                   state_D.params))
        metrics['D_regul_scale'] = dynamic_scale.scale

    return new_state_D, metrics
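
# R1 regularization penalizes the gradient of the discriminator on real
# images, (r1_gamma / 2) * E[||grad_x D(x)||^2]. As with the generator, the
# penalty is pre-multiplied by D_reg_interval to compensate for the lazy
# regularization schedule.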

def eval_step_G(generator, params, z_latent, labels, truncation):
    image_gen = generator.apply(params, z_latent, labels, truncation_psi=truncation, train=False, noise_mode='const')
    return image_gen
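
# Example call (mirroring training.py): sample from the EMA generator with
# truncation, e.g.
#   images = eval_step_G(generator_ema, params=params_ema_G,
#                        z_latent=z_latent1[0], labels=labels, truncation=0.5)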
training_utils.py
ADDED
@@ -0,0 +1,174 @@
import jax
import jax.numpy as jnp
from jaxlib.xla_extension import DeviceArray
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
from flax.training import train_state
from flax import struct
import numpy as np
from PIL import Image
from urllib.request import Request, urlopen
import urllib.error
from typing import Any, Callable

def sync_moving_stats(state):
    """
    Sync moving statistics across devices.

    Args:
        state (train_state.TrainState): Training state.

    Returns:
        (train_state.TrainState): Updated training state.
    """
    cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
    return state.replace(moving_stats=cross_replica_mean(state.moving_stats))

def update_generator_ema(state_G, params_ema_G, config, ema_beta=None):
    """
    Update the exponential moving average of the generator weights.
    Moving stats and noise constants are copied over unchanged.

    Args:
        state_G (train_state.TrainState): Generator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
        config (Any): Config object.
        ema_beta (float): Beta parameter of the ema. If None, it is computed
                          from 'ema_kimg' and 'batch_size'.

    Returns:
        (frozen_dict.FrozenDict): Updated parameters of the ema generator.
    """
    def _update_ema(src, trg, beta):
        for name, src_child in src.items():
            if isinstance(src_child, DeviceArray):
                trg[name] = src[name] + beta * (trg[name] - src[name])
            else:
                _update_ema(src_child, trg[name], beta)

    if ema_beta is None:
        ema_nimg = config.ema_kimg * 1000
        ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8))

    params_ema_G = params_ema_G.unfreeze()

    # Copy over moving stats
    params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats
    params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts

    # Update the exponential moving average of the trainable parameters
    _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta)
    _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta)

    params_ema_G = frozen_dict.freeze(params_ema_G)
    return params_ema_G
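
# Worked example (values assumed for illustration): with ema_kimg=10 and
# batch_size=32, ema_nimg = 10000 and ema_beta = 0.5 ** (32 / 10000) ≈ 0.9978,
# so the contribution of old weights halves roughly every 10k training images.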

class TrainStateG(train_state.TrainState):
    """
    Generator train state for a single Optax optimizer.

    Attributes:
        apply_mapping (Callable): Apply function of the Mapping Network.
        apply_synthesis (Callable): Apply function of the Synthesis Network.
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients in the main step.
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients in the regularization step.
        epoch (int): Current epoch.
        moving_stats (Any): Moving average of the latent W.
        noise_consts (Any): Noise constants from synthesis layers.
    """
    apply_mapping: Callable = struct.field(pytree_node=False)
    apply_synthesis: Callable = struct.field(pytree_node=False)
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
    moving_stats: Any = None
    noise_consts: Any = None


class TrainStateD(train_state.TrainState):
    """
    Discriminator train state for a single Optax optimizer.

    Attributes:
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients in the main step.
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients in the regularization step.
        epoch (int): Current epoch.
    """
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
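
# A minimal construction sketch (argument values are placeholders, not the
# actual setup code): train_state.TrainState.create forwards extra keyword
# arguments to the dataclass, so the custom fields are set alongside params
# and tx, e.g.
#   state_G = TrainStateG.create(apply_fn=generator.apply,
#                                apply_mapping=generator.apply_mapping,
#                                apply_synthesis=generator.apply_synthesis,
#                                params=params_G, tx=optax.adam(2e-3),
#                                dynamic_scale_main=dynamic_scale_lib.DynamicScale(),
#                                dynamic_scale_reg=dynamic_scale_lib.DynamicScale(),
#                                epoch=0)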

def get_training_snapshot(image_real, image_gen, max_num=10):
    """
    Creates a snapshot of generated images and real images.

    Args:
        image_real (DeviceArray): Batch of real images, shape [B, H, W, C].
        image_gen (DeviceArray): Batch of generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for the snapshot.

    Returns:
        (PIL.Image): Training snapshot. Top row: generated images, bottom row: real images.
    """
    if image_real.shape[0] > max_num:
        image_real = image_real[:max_num]
    if image_gen.shape[0] > max_num:
        image_gen = image_gen[:max_num]

    image_real = jnp.split(image_real, image_real.shape[0], axis=0)
    image_gen = jnp.split(image_gen, image_gen.shape[0], axis=0)

    image_real = [jnp.squeeze(x, axis=0) for x in image_real]
    image_gen = [jnp.squeeze(x, axis=0) for x in image_gen]

    image_real = jnp.concatenate(image_real, axis=1)
    image_gen = jnp.concatenate(image_gen, axis=1)

    image_gen = (image_gen - np.min(image_gen)) / (np.max(image_gen) - np.min(image_gen))
    image_real = (image_real - np.min(image_real)) / (np.max(image_real) - np.min(image_real))
    image = jnp.concatenate((image_gen, image_real), axis=0)

    image = np.uint8(image * 255)
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return Image.fromarray(image)

def get_eval_snapshot(image, max_num=10):
    """
    Creates a snapshot of generated images.

    Args:
        image (DeviceArray): Generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for the snapshot.

    Returns:
        (PIL.Image): Eval snapshot.
    """
    if image.shape[0] > max_num:
        image = image[:max_num]

    image = jnp.split(image, image.shape[0], axis=0)
    image = [jnp.squeeze(x, axis=0) for x in image]
    image = jnp.concatenate(image, axis=1)
    image = (image - np.min(image)) / (np.max(image) - np.min(image))
    image = np.uint8(image * 255)
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return Image.fromarray(image)

def get_vm_name():
    gcp_metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance-id"
    req = Request(gcp_metadata_url)
    req.add_header('Metadata-Flavor', 'Google')
    instance_id = None
    try:
        with urlopen(req) as url:
            instance_id = url.read().decode()
    except urllib.error.URLError:
        # metadata.google.internal is not reachable, e.g. when running
        # locally outside GCP; fall back to None.
        pass
    return instance_id