import jax
import jax.numpy as jnp
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
import optax
import numpy as np
import functools
import wandb
import time

import stylegan2
import data_pipeline
import checkpoint
import training_utils
import training_steps
from fid import FID

import logging

logger = logging.getLogger(__name__)


def tree_shape(item):
    return jax.tree_map(lambda c: c.shape, item)


def train_and_evaluate(config):
    num_devices = jax.device_count()              # Total number of devices across all hosts.
    num_local_devices = jax.local_device_count()  # Number of devices on this host.
    num_workers = jax.process_count()

    # --------------------------------------
    # Data
    # --------------------------------------
    ds_train, dataset_info = data_pipeline.get_data(data_dir=config.data_dir,
                                                    img_size=config.resolution,
                                                    img_channels=config.img_channels,
                                                    num_classes=config.c_dim,
                                                    num_local_devices=num_local_devices,
                                                    batch_size=config.batch_size)

    # --------------------------------------
    # Seeding and Precision
    # --------------------------------------
    rng = jax.random.PRNGKey(config.random_seed)

    if config.mixed_precision:
        dtype = jnp.float16
    elif config.bf16:
        dtype = jnp.bfloat16
    else:
        dtype = jnp.float32
    logger.info(f'Running on dtype {dtype}')

    platform = jax.local_devices()[0].platform
    if config.mixed_precision and platform == 'gpu':
        dynamic_scale_G_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_G_reg = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_reg = dynamic_scale_lib.DynamicScale()
        clip_conv = 256
        num_fp16_res = 4
    else:
        dynamic_scale_G_main = None
        dynamic_scale_D_main = None
        dynamic_scale_G_reg = None
        dynamic_scale_D_reg = None
        clip_conv = None
        num_fp16_res = 0

    # --------------------------------------
    # Initialize Models
    # --------------------------------------
    logger.info('Initialize models...')
    rng, init_rng = jax.random.split(rng)

    # Generator initialization for training
    start_mn = time.time()
    logger.info('Creating MappingNetwork...')
    mapping_net = stylegan2.MappingNetwork(z_dim=config.z_dim,
                                           c_dim=config.c_dim,
                                           w_dim=config.w_dim,
                                           num_ws=int(np.log2(config.resolution)) * 2 - 3,
                                           num_layers=8,
                                           dtype=dtype)
    mapping_net_vars = mapping_net.init(init_rng,
                                        jnp.ones((1, config.z_dim)),
                                        jnp.ones((1, config.c_dim)))
    mapping_net_params, moving_stats = mapping_net_vars['params'], mapping_net_vars['moving_stats']
    logger.info(f'MappingNetwork took {time.time() - start_mn:.2f}s')

    logger.info('Creating SynthesisNetwork...')
    start_sn = time.time()
    synthesis_net = stylegan2.SynthesisNetwork(resolution=config.resolution,
                                               num_channels=config.img_channels,
                                               w_dim=config.w_dim,
                                               fmap_base=config.fmap_base,
                                               num_fp16_res=num_fp16_res,
                                               clip_conv=clip_conv,
                                               dtype=dtype)
    synthesis_net_vars = synthesis_net.init(init_rng, jnp.ones((1, mapping_net.num_ws, config.w_dim)))
    synthesis_net_params, noise_consts = synthesis_net_vars['params'], synthesis_net_vars['noise_consts']
    logger.info(f'SynthesisNetwork took {time.time() - start_sn:.2f}s')

    params_G = frozen_dict.FrozenDict({'mapping': mapping_net_params,
                                       'synthesis': synthesis_net_params})

    # Discriminator initialization for training
    logger.info('Creating Discriminator...')
    start_d = time.time()
    discriminator = stylegan2.Discriminator(resolution=config.resolution,
                                            num_channels=config.img_channels,
                                            c_dim=config.c_dim,
                                            mbstd_group_size=config.mbstd_group_size,
                                            num_fp16_res=num_fp16_res,
                                            clip_conv=clip_conv,
                                            dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    params_D = discriminator.init(init_rng,
                                  jnp.ones((1, config.resolution, config.resolution, config.img_channels)),
                                  jnp.ones((1, config.c_dim)))
    logger.info(f'Discriminator took {time.time() - start_d:.2f}s')
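
    # The training generator keeps its mapping and synthesis parameters side by side in params_G,
    # while the EMA copy created below uses the combined Generator module; that EMA copy is what
    # FID evaluation and the image snapshots sample from.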
    # Exponential moving average (EMA) Generator initialization
    logger.info('Creating Generator EMA...')
    start_g = time.time()
    generator_ema = stylegan2.Generator(resolution=config.resolution,
                                        num_channels=config.img_channels,
                                        z_dim=config.z_dim,
                                        c_dim=config.c_dim,
                                        w_dim=config.w_dim,
                                        num_ws=int(np.log2(config.resolution)) * 2 - 3,
                                        num_mapping_layers=8,
                                        fmap_base=config.fmap_base,
                                        num_fp16_res=num_fp16_res,
                                        clip_conv=clip_conv,
                                        dtype=dtype)
    params_ema_G = generator_ema.init(init_rng,
                                      jnp.ones((1, config.z_dim)),
                                      jnp.ones((1, config.c_dim)))
    logger.info(f'Generator EMA took {time.time() - start_g:.2f}s')

    # --------------------------------------
    # Initialize States and Optimizers
    # --------------------------------------
    logger.info('Initialize states...')
    tx_G = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
    tx_D = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)

    state_G = training_utils.TrainStateG.create(apply_fn=None,
                                                apply_mapping=mapping_net.apply,
                                                apply_synthesis=synthesis_net.apply,
                                                params=params_G,
                                                moving_stats=moving_stats,
                                                noise_consts=noise_consts,
                                                tx=tx_G,
                                                dynamic_scale_main=dynamic_scale_G_main,
                                                dynamic_scale_reg=dynamic_scale_G_reg,
                                                epoch=0)

    state_D = training_utils.TrainStateD.create(apply_fn=discriminator.apply,
                                                params=params_D,
                                                tx=tx_D,
                                                dynamic_scale_main=dynamic_scale_D_main,
                                                dynamic_scale_reg=dynamic_scale_D_reg,
                                                epoch=0)

    # Copy over the parameters from the training generator to the EMA generator
    params_ema_G = training_utils.update_generator_ema(state_G, params_ema_G, config, ema_beta=0)

    # Running mean of the path length for path length regularization
    pl_mean = jnp.zeros((), dtype=dtype)

    step = 0
    epoch_offset = 0
    best_fid_score = np.inf
    ckpt_path = None

    if config.resume_run_id is not None:
        # Resume training from an existing checkpoint
        ckpt_path = checkpoint.get_latest_checkpoint(config.ckpt_dir)
        logger.info(f'Resume training from checkpoint: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        step = ckpt['step']
        epoch_offset = ckpt['epoch']
        best_fid_score = ckpt['fid_score']
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']
        config = ckpt['config']
    elif config.load_from_pkl is not None:
        # Load checkpoint and start a new run
        ckpt_path = config.load_from_pkl
        logger.info(f'Load model state from: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']

    # Replicate states across devices
    pl_mean = flax.jax_utils.replicate(pl_mean)
    state_G = flax.jax_utils.replicate(state_G)
    state_D = flax.jax_utils.replicate(state_D)

    # --------------------------------------
    # Precompile train and eval steps
    # --------------------------------------
    logger.info('Precompile training steps...')
    p_main_step_G = jax.pmap(training_steps.main_step_G, axis_name='batch')
    p_regul_step_G = jax.pmap(functools.partial(training_steps.regul_step_G, config=config), axis_name='batch')

    p_main_step_D = jax.pmap(training_steps.main_step_D, axis_name='batch')
    p_regul_step_D = jax.pmap(functools.partial(training_steps.regul_step_D, config=config), axis_name='batch')

    # --------------------------------------
    # Training
    # --------------------------------------
    logger.info('Start training...')
    fid_metric = FID(generator_ema, ds_train, config)

    # Dict to collect training statistics / losses
    metrics = {}
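
    # Note: config.batch_size is the per-device batch size. Batches arrive sharded as
    # (num_local_devices, batch_size, H, W, C), so the effective (global) batch size below
    # is config.batch_size * num_devices.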
    num_imgs_processed = 0
    num_steps_per_epoch = dataset_info['num_examples'] // (config.batch_size * num_devices)
    effective_batch_size = config.batch_size * num_devices

    if config.wandb and jax.process_index() == 0:
        # do some more logging
        wandb.config.effective_batch_size = effective_batch_size
        wandb.config.num_steps_per_epoch = num_steps_per_epoch
        wandb.config.num_workers = num_workers
        wandb.config.device_count = num_devices
        wandb.config.num_examples = dataset_info['num_examples']
        wandb.config.vm_name = training_utils.get_vm_name()

    for epoch in range(epoch_offset, config.num_epochs):
        if config.wandb and jax.process_index() == 0:
            wandb.log({'training/epochs': epoch}, step=step)

        for batch in data_pipeline.prefetch(ds_train, config.num_prefetch):
            assert batch['image'].shape[1] == config.batch_size, \
                f"Mismatched batch (batch size: {config.batch_size}, this batch: {batch['image'].shape[1]})"
            # pbar.update(num_devices * config.batch_size)
            iteration_start_time = time.time()

            if config.c_dim == 0:
                # No labels in the dataset
                batch['label'] = None

            # Create two latent noise vectors and combine them for the style mixing regularization
            rng, key = jax.random.split(rng)
            z_latent1 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)

            rng, key = jax.random.split(rng)
            z_latent2 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)

            # Split PRNGs across devices
            rkey = jax.random.split(key, num=num_local_devices)
            mixing_prob = flax.jax_utils.replicate(config.mixing_prob)

            # --------------------------------------
            # Update Discriminator
            # --------------------------------------
            time_d_start = time.time()
            state_D, metrics = p_main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            time_d_end = time.time()
            if step % config.D_reg_interval == 0:
                state_D, metrics = p_regul_step_D(state_D, batch, metrics)

            # --------------------------------------
            # Update Generator
            # --------------------------------------
            time_g_start = time.time()
            state_G, metrics = p_main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            if step % config.G_reg_interval == 0:
                H, W = batch['image'].shape[-3], batch['image'].shape[-2]
                rng, key = jax.random.split(rng)
                pl_noise = jax.random.normal(key, batch['image'].shape, dtype=dtype) / np.sqrt(H * W)
                state_G, metrics, pl_mean = p_regul_step_G(state_G, batch, z_latent1, pl_noise, pl_mean, metrics, rng=rkey)

            params_ema_G = training_utils.update_generator_ema(flax.jax_utils.unreplicate(state_G),
                                                               params_ema_G,
                                                               config)
            time_g_end = time.time()

            # --------------------------------------
            # Logging and Checkpointing
            # --------------------------------------
            if step % config.save_every == 0 and config.disable_fid:
                # If FID evaluation is disabled, a checkpoint will be saved every 'save_every' steps.
                if jax.process_index() == 0:
                    logger.info('Saving checkpoint...')
                    checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch)

            num_imgs_processed += num_devices * config.batch_size

            if step % config.eval_fid_every == 0 and not config.disable_fid:
                # If FID evaluation is enabled, only save a checkpoint if FID score is better.
                if jax.process_index() == 0:
                    logger.info('Computing FID...')
                    fid_score = fid_metric.compute_fid(params_ema_G).item()
                    if config.wandb:
                        wandb.log({'training/gen/fid': fid_score}, step=step)
                    logger.info(f'Computed FID: {fid_score:.2f}')

                    if fid_score < best_fid_score:
                        best_fid_score = fid_score
                        logger.info(f'New best FID score ({best_fid_score:.3f}). Saving checkpoint...')
                        ts = time.time()
                        checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=fid_score)
                        te = time.time()
                        logger.info(f'... successfully saved checkpoint in {(te - ts) / 60:.1f}min')

            sec_per_kimg = (time.time() - iteration_start_time) / (num_devices * config.batch_size / 1000.0)
            time_taken_g = time_g_end - time_g_start
            time_taken_d = time_d_end - time_d_start
            time_taken_per_step = time.time() - iteration_start_time

            g_loss = jnp.mean(metrics['G_loss']).item()
            d_loss = jnp.mean(metrics['D_loss']).item()

            if config.wandb and jax.process_index() == 0:
                # wandb logging - happens every step
                wandb.log({'training/gen/loss': g_loss}, step=step, commit=False)
                wandb.log({'training/dis/loss': d_loss}, step=step, commit=False)
                wandb.log({'training/dis/fake_logits': jnp.mean(metrics['fake_logits']).item()}, step=step, commit=False)
                wandb.log({'training/dis/real_logits': jnp.mean(metrics['real_logits']).item()}, step=step, commit=False)
                wandb.log({'training/time_taken_g': time_taken_g, 'training/time_taken_d': time_taken_d}, step=step, commit=False)
                wandb.log({'training/time_taken_per_step': time_taken_per_step}, step=step, commit=False)
                wandb.log({'training/num_imgs_trained': num_imgs_processed}, step=step, commit=False)
                wandb.log({'training/sec_per_kimg': sec_per_kimg}, step=step)

            if step % config.log_every == 0:
                # console logging - happens every log_every steps
                logger.info(f'Total steps: {step:>6,} - epoch {epoch:>3,}/{config.num_epochs} @ {step % num_steps_per_epoch:>6,}/{num_steps_per_epoch:,} - '
                            f'G loss: {g_loss:.5f} - D loss: {d_loss:.5f} - sec/kimg: {sec_per_kimg:.2f}s - time per step: {time_taken_per_step:.3f}s')

            if step % config.generate_samples_every == 0 and config.wandb and jax.process_index() == 0:
                # Generate training images
                train_snapshot = training_utils.get_training_snapshot(image_real=flax.jax_utils.unreplicate(batch['image']),
                                                                      image_gen=flax.jax_utils.unreplicate(metrics['image_gen']),
                                                                      max_num=10)
                wandb.log({'training/snapshot': wandb.Image(train_snapshot)}, commit=False, step=step)

                # Generate evaluation images
                labels = None if config.c_dim == 0 else batch['label'][0]
                image_gen_eval = training_steps.eval_step_G(generator_ema,
                                                            params=params_ema_G,
                                                            z_latent=z_latent1[0],
                                                            labels=labels,
                                                            truncation=1)
                image_gen_eval_trunc = training_steps.eval_step_G(generator_ema,
                                                                  params=params_ema_G,
                                                                  z_latent=z_latent1[0],
                                                                  labels=labels,
                                                                  truncation=0.5)
                eval_snapshot = training_utils.get_eval_snapshot(image=image_gen_eval, max_num=10)
                eval_snapshot_trunc = training_utils.get_eval_snapshot(image=image_gen_eval_trunc, max_num=10)
                wandb.log({'eval/snapshot': wandb.Image(eval_snapshot)}, commit=False, step=step)
                wandb.log({'eval/snapshot_trunc': wandb.Image(eval_snapshot_trunc)}, step=step)

            step += 1

        # Sync moving stats across devices
        state_G = training_utils.sync_moving_stats(state_G)

        # Sync moving average of path length mean (Generator regularization)
        pl_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name='batch'), axis_name='batch')(pl_mean)
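        # pl_mean is updated per device inside p_regul_step_G; averaging it here keeps the path
        # length running mean identical across replicas before the next epoch starts.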