import time
from collections import OrderedDict
from os.path import join as pjoin

import torch
import torch.optim as optim
from diffusers import DDPMScheduler
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter

from utils.utils import print_current_loss


class DDPMTrainer(object):
    """Trains a text-conditioned motion denoiser with the DDPM objective."""

    def __init__(self, args, model, accelerator, model_ema=None):
        self.opt = args
        self.accelerator = accelerator
        self.device = self.accelerator.device
        self.model = model
        self.diffusion_steps = args.diffusion_steps
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=self.diffusion_steps,
            beta_schedule=args.beta_schedule,
            variance_type="fixed_small",
            prediction_type=args.prediction_type,
            clip_sample=False,
        )
        self.model_ema = model_ema
        if args.is_train:
            # Per-element MSE; masking and averaging happen in masked_l2().
            self.mse_criterion = torch.nn.MSELoss(reduction="none")
        accelerator.print("Diffusion config:\n", self.noise_scheduler.config)

        if self.accelerator.is_main_process:
            starttime = time.strftime("%Y-%m-%d_%H:%M:%S")
            print("Start experiment:", starttime)
            self.writer = SummaryWriter(
                log_dir=pjoin(args.save_root, "logs_") + starttime[:16],
                comment=starttime[:16],
                flush_secs=60,
            )
        self.accelerator.wait_for_everyone()

        self.optimizer = optim.AdamW(
            self.model.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay
        )
        self.scheduler = (
            ExponentialLR(self.optimizer, gamma=args.decay_rate)
            if args.decay_rate > 0
            else None
        )

    @staticmethod
    def zero_grad(opt_list):
        for opt in opt_list:
            opt.zero_grad()

    def clip_norm(self, network_list):
        for network in network_list:
            self.accelerator.clip_grad_norm_(
                network.parameters(), self.opt.clip_grad_norm
            )

    @staticmethod
    def step(opt_list):
        for opt in opt_list:
            opt.step()

    def forward(self, batch_data):
        caption, motions, m_lens = batch_data
        motions = motions.detach().float()

        x_start = motions
        B, T = x_start.shape[:2]
        cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
        self.src_mask = self.generate_src_mask(T, cur_len).to(x_start.device)

        # 1. Sample the noise that will be added to the motion.
        real_noise = torch.randn_like(x_start)

        # 2. Sample a random timestep for each motion in the batch.
        t = torch.randint(0, self.diffusion_steps, (B,), device=self.device)
        self.timesteps = t

        # 3. Add noise to the motion according to the noise magnitude at each
        # timestep (the forward diffusion process).
        x_t = self.noise_scheduler.add_noise(x_start, real_noise, t)
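        # Note: DDPMScheduler.add_noise applies the closed-form forward
        # process q(x_t | x_0), i.e.
        #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        # so the noisy sample at any timestep is produced in a single step
        # rather than by iterating through the chain.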
        # 4. Predict with the network and pick the regression target that
        # matches the scheduler's prediction type.
        self.prediction = self.model(x_t, t, text=caption)

        if self.opt.prediction_type == "sample":
            self.target = x_start
        elif self.opt.prediction_type == "epsilon":
            self.target = real_noise
        elif self.opt.prediction_type == "v_prediction":
            self.target = self.noise_scheduler.get_velocity(x_start, real_noise, t)

    def masked_l2(self, a, b, mask, weights):
        # Mean over the feature dim, then average over only the valid
        # (unpadded) frames of each sequence.
        loss = self.mse_criterion(a, b).mean(dim=-1)  # (batch_size, motion_length)
        loss = (loss * mask).sum(-1) / mask.sum(-1)  # (batch_size,)
        loss = (loss * weights).mean()
        return loss

    def backward_G(self):
        loss_logs = OrderedDict({})
        mse_loss_weights = torch.ones_like(self.timesteps)
        loss_logs["loss_mot_rec"] = self.masked_l2(
            self.prediction, self.target, self.src_mask, mse_loss_weights
        )
        self.loss = loss_logs["loss_mot_rec"]
        return loss_logs

    def update(self):
        self.zero_grad([self.optimizer])
        loss_logs = self.backward_G()
        self.accelerator.backward(self.loss)
        self.clip_norm([self.model])
        self.step([self.optimizer])
        return loss_logs

    def generate_src_mask(self, T, length):
        # Build a (B, T) mask that is 1 for valid frames and 0 for padding.
        B = len(length)
        src_mask = torch.ones(B, T)
        for i in range(B):
            src_mask[i, length[i]:] = 0
        return src_mask

    def train_mode(self):
        self.model.train()
        if self.model_ema:
            self.model_ema.train()

    def eval_mode(self):
        self.model.eval()
        if self.model_ema:
            self.model_ema.eval()

    def save(self, file_name, total_it):
        state = {
            "opt_encoder": self.optimizer.state_dict(),
            "total_it": total_it,
            "encoder": self.accelerator.unwrap_model(self.model).state_dict(),
        }
        if self.model_ema:
            # Store the averaged weights (the EMA wrapper's inner module).
            state["model_ema"] = self.accelerator.unwrap_model(
                self.model_ema
            ).module.state_dict()
        torch.save(state, file_name)

    def load(self, model_dir):
        checkpoint = torch.load(model_dir, map_location=self.device)
        self.optimizer.load_state_dict(checkpoint["opt_encoder"])
        if self.model_ema:
            self.model_ema.load_state_dict(checkpoint["model_ema"], strict=True)
        self.model.load_state_dict(checkpoint["encoder"], strict=True)
        return checkpoint.get("total_it", 0)

    def train(self, train_loader):
        it = 0
        if self.opt.is_continue:
            model_path = pjoin(self.opt.model_dir, self.opt.continue_ckpt)
            it = self.load(model_path)
            self.accelerator.print(
                f"Resuming training from iteration {it} ({model_path})"
            )
        start_time = time.time()

        logs = OrderedDict()
        self.dataset = train_loader.dataset
        self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema = (
            self.accelerator.prepare(
                self.model,
                self.mse_criterion,
                self.optimizer,
                train_loader,
                self.model_ema,
            )
        )
        num_epochs = (self.opt.num_train_steps - it) // len(train_loader) + 1
        self.accelerator.print(f"Need to train for {num_epochs} epochs...")
        for epoch in range(num_epochs):
            self.train_mode()
            for i, batch_data in enumerate(train_loader):
                self.forward(batch_data)
                log_dict = self.update()
                it += 1
                if self.model_ema and it % self.opt.model_ema_steps == 0:
                    self.accelerator.unwrap_model(self.model_ema).update_parameters(
                        self.model
                    )
                # Accumulate the per-step losses for logging.
                for k, v in log_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict({})
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(
                        self.accelerator, start_time, it, mean_loss, epoch, inner_iter=i
                    )
                    if self.accelerator.is_main_process:
                        self.writer.add_scalar("loss", mean_loss["loss_mot_rec"], it)
                    self.accelerator.wait_for_everyone()

                # Save model
                if (
                    it % self.opt.save_interval == 0
                    and self.accelerator.is_main_process
                ):
                    self.save(pjoin(self.opt.model_dir, "latest.tar"), it)
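                # Only the main process writes checkpoints; the barrier below
                # keeps the other ranks from racing ahead of the save.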
                self.accelerator.wait_for_everyone()

                if (self.scheduler is not None) and (
                    it % self.opt.update_lr_steps == 0
                ):
                    self.scheduler.step()

        # Save the last checkpoint if it wasn't already saved.
        if it % self.opt.save_interval != 0 and self.accelerator.is_main_process:
            self.save(pjoin(self.opt.model_dir, "latest.tar"), it)
        self.accelerator.wait_for_everyone()
        self.accelerator.print("FINISH")
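
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `TrainOptions` and `MotionTransformer`
# are hypothetical stand-ins for the project's real option parser and
# denoiser network; the dataloader is assumed to yield
# (caption, motions, m_lens) batches as consumed by DDPMTrainer.forward().
#
#     from accelerate import Accelerator
#
#     accelerator = Accelerator()
#     opt = TrainOptions().parse()       # hypothetical option container
#     model = MotionTransformer(opt)     # hypothetical denoiser network
#     trainer = DDPMTrainer(opt, model, accelerator)
#     trainer.train(train_loader)
# ---------------------------------------------------------------------------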