import torch


def sample_x0(x1):
    """Sample Gaussian noise x0 with the same structure as the data x1.

    Args:
        x1: data batch, either a tensor of shape [batch, *dim] or a
            list/tuple of per-sample tensors (which may differ in shape).

    Returns:
        For tensor input, a tensor shaped like ``x1``; for list/tuple input,
        a list of tensors shaped like each element of ``x1``. Values are
        standard-normal noise with the device/dtype of the corresponding
        input tensor (``torch.randn_like``).
    """
    if isinstance(x1, (list, tuple)):
        return [torch.randn_like(sample) for sample in x1]
    return torch.randn_like(x1)


def sample_timestep(x1):
    """Sample one timestep per batch element.

    t = sigmoid(u) with u ~ N(0, 1), i.e. a logit-normal distribution over
    (0, 1) (up to floating-point rounding).

    Args:
        x1: data batch (tensor or list/tuple of tensors); only ``len(x1)``
            and the device/dtype of ``x1[0]`` are used.

    Returns:
        Tensor of shape [len(x1)] on the device/dtype of ``x1[0]``.
    """
    u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
    t = torch.sigmoid(u)
    # Tensor.to(other) matches both the device and dtype of the reference.
    return t.to(x1[0])


def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
    """Velocity-matching loss for linear-interpolation flow matching.

    Draws noise x0 and timesteps t, builds xt = t * x1 + (1 - t) * x0, and
    regresses the model output onto the target velocity ut = x1 - x0.

    Args:
        model: callable invoked as ``model(xt, t, **model_kwargs)``. It must
            return a prediction with the same structure as ``x1``: a tensor,
            or a sequence of tensors matching ``x1`` element-wise.
        x1: data batch, a tensor of shape [batch, *dim] or a list/tuple of
            per-sample tensors.
        model_kwargs: optional extra keyword arguments forwarded to ``model``.
        snr_type: currently unused; kept for API compatibility.

    Returns:
        dict with key ``"loss"``: tensor of shape [batch] holding the
        per-sample mean squared error.
    """
    if model_kwargs is None:
        model_kwargs = {}

    # Keep the sampling order (x0 first, then t) so RNG streams are stable.
    x0 = sample_x0(x1)
    t = sample_timestep(x1)

    if isinstance(x1, (list, tuple)):
        xt = [t_i * data + (1 - t_i) * noise for t_i, data, noise in zip(t, x1, x0)]
        ut = [data - noise for data, noise in zip(x1, x0)]
    else:
        # Reshape t to [batch, 1, 1, ...] so it broadcasts over data dims.
        t_ = t.view(t.size(0), *([1] * (x1.dim() - 1)))
        xt = t_ * x1 + (1 - t_) * x0
        ut = x1 - x0

    model_output = model(xt, t, **model_kwargs)

    terms = {}
    if isinstance(x1, (list, tuple)):
        assert len(model_output) == len(ut) == len(x1)
        # Samples may differ in shape, so reduce each one on its own; a single
        # pass suffices (no outer loop over the batch).
        terms["loss"] = torch.stack(
            [((target - pred) ** 2).mean() for target, pred in zip(ut, model_output)],
            dim=0,
        )
    else:
        terms["loss"] = mean_flat((model_output - ut) ** 2)
    return terms


def mean_flat(x):
    """Take the mean over all non-batch dimensions.

    Args:
        x: tensor of shape [batch, *dim].

    Returns:
        Tensor of shape [batch]. A tensor with fewer than two dimensions has
        no non-batch dimensions and is returned unchanged.
    """
    if x.dim() < 2:
        # An empty ``dim`` list may make torch reductions collapse *all*
        # dimensions (legacy behavior), turning [batch] into a scalar.
        return x
    return torch.mean(x, dim=list(range(1, x.dim())))