|
import torch
|
|
|
|
|
|
def sample_x0(x1):
    """Draw standard-normal noise x0 with the same layout as x1.

    Args:
      x1 - data point; either a single tensor [batch, *dim] or a
           list/tuple of tensors

    Returns:
      A noise tensor created with torch.randn_like(x1) when x1 is a tensor,
      otherwise a list holding one such noise tensor per element of x1.
    """
    # Guard clause: the plain-tensor case needs a single draw.
    if not isinstance(x1, (list, tuple)):
        return torch.randn_like(x1)

    # Sequence case: each element gets its own noise tensor; the result is
    # always a list, even when x1 is a tuple.
    return [torch.randn_like(sample) for sample in x1]
|
|
|
|
def sample_timestep(x1):
    """Draw one timestep per element of x1 as the sigmoid of a normal sample.

    Args:
      x1 - data point; a tensor [batch, *dim] or a list/tuple of tensors.
           Only len(x1) and x1[0] are used.

    Returns:
      1-D tensor with len(x1) entries, converted to match x1[0].
    """
    batch = len(x1)
    gaussian = torch.normal(mean=0.0, std=1.0, size=(batch,))

    # Logistic sigmoid of the Gaussian draw.
    timesteps = 1 / (1 + torch.exp(-gaussian))

    # Tensor.to(other) adopts the dtype and device of `other`.
    return timesteps.to(x1[0])
|
|
|
|
|
|
def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
    """Compute the squared-error training loss for the model.

    For every data point a noise sample x0 and a timestep t are drawn, the
    interpolant xt = t * x1 + (1 - t) * x0 is passed to the model, and the
    model output is regressed onto the target ut = x1 - x0.

    Args:
    - model: backbone model; called as model(xt, t, **model_kwargs)
    - x1: datapoint; a tensor [batch, *dim] or a list/tuple of tensors
    - model_kwargs: additional keyword arguments for the model
    - snr_type: not read by this function; timesteps always come from
      sample_timestep

    Returns:
      dict with a single key "loss". For list/tuple input it is a stacked
      tensor with one mean squared error per element; for tensor input it is
      mean_flat of the squared error.
    """
    if model_kwargs is None:
        model_kwargs = {}

    is_sequence = isinstance(x1, (list, tuple))

    # Keep this order (noise first, then timesteps) so the RNG stream is
    # consumed exactly as before.
    x0 = sample_x0(x1)
    t = sample_timestep(x1)

    if is_sequence:
        # Elements are combined one by one, each with its own scalar timestep.
        xt = [t_i * x1_i + (1 - t_i) * x0_i for t_i, x1_i, x0_i in zip(t, x1, x0)]
        ut = [x1_i - x0_i for x1_i, x0_i in zip(x1, x0)]
    else:
        # Reshape t to [batch, 1, ..., 1] so it broadcasts against x1.
        trailing_ones = [1] * (x1.dim() - 1)
        t_ = t.view(t.size(0), *trailing_ones)
        xt = t_ * x1 + (1 - t_) * x0
        ut = x1 - x0

    model_output = model(xt, t, **model_kwargs)

    terms = {}
    if is_sequence:
        assert len(model_output) == len(ut) == len(x1)
        # Computed once; the previous version rebuilt this identical stack
        # inside a redundant outer loop over the batch.
        terms["loss"] = torch.stack(
            [((target - pred) ** 2).mean() for target, pred in zip(ut, model_output)],
            dim=0,
        )
    else:
        terms["loss"] = mean_flat((model_output - ut) ** 2)

    return terms
|
|
|
|
|
|
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.

    Every dimension of x except dimension 0 is averaged out.
    """
    non_batch_axes = list(range(1, x.dim()))
    return x.mean(dim=non_batch_axes)
|
|
|