LiruiZhao committed
Commit 887de2c
1 Parent(s): 30e6374

clean up irrelevant code

configs/generate_diffree.yaml CHANGED
@@ -1,6 +1,3 @@
-# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
-# See more details in LICENSE.
-
 model:
   base_learning_rate: 5.0e-05
   target: ldm.models.diffusion.ddpm_diffree.LatentDiffusion
stable_diffusion/ldm/models/diffusion/ddpm_diffree.py CHANGED
@@ -6,9 +6,6 @@ https://github.com/CompVis/taming-transformers
 -- merci
 """
 
-# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
-# See more details in LICENSE.
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
stable_diffusion/ldm/models/diffusion/ddpm_edit.py DELETED
@@ -1,1462 +0,0 @@
1
- """
2
- wild mixture of
3
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
- https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
- https://github.com/CompVis/taming-transformers
6
- -- merci
7
- """
8
-
9
- # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
10
- # See more details in LICENSE.
11
-
12
- import torch
13
- import torch.nn as nn
14
- import numpy as np
15
- import pytorch_lightning as pl
16
- from torch.optim.lr_scheduler import LambdaLR
17
- from einops import rearrange, repeat
18
- from contextlib import contextmanager
19
- from functools import partial
20
- from tqdm import tqdm
21
- from torchvision.utils import make_grid
22
- from pytorch_lightning.utilities.distributed import rank_zero_only
23
-
24
- from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
25
- from ldm.modules.ema import LitEma
26
- from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
27
- from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
28
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
29
- from ldm.models.diffusion.ddim import DDIMSampler
30
-
31
-
32
- __conditioning_keys__ = {'concat': 'c_concat',
33
- 'crossattn': 'c_crossattn',
34
- 'adm': 'y'}
35
-
36
-
37
- def disabled_train(self, mode=True):
38
- """Overwrite model.train with this function to make sure train/eval mode
39
- does not change anymore."""
40
- return self
41
-
42
-
43
- def uniform_on_device(r1, r2, shape, device):
44
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
45
-
46
-
47
- class DDPM(pl.LightningModule):
48
- # classic DDPM with Gaussian diffusion, in image space
49
- def __init__(self,
50
- unet_config,
51
- timesteps=1000,
52
- beta_schedule="linear",
53
- loss_type="l2",
54
- ckpt_path=None,
55
- ignore_keys=[],
56
- load_only_unet=False,
57
- monitor="val/loss",
58
- use_ema=True,
59
- first_stage_key="image",
60
- image_size=256,
61
- channels=3,
62
- log_every_t=100,
63
- clip_denoised=True,
64
- linear_start=1e-4,
65
- linear_end=2e-2,
66
- cosine_s=8e-3,
67
- given_betas=None,
68
- original_elbo_weight=0.,
69
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
70
- l_simple_weight=1.,
71
- conditioning_key=None,
72
- parameterization="eps", # all assuming fixed variance schedules
73
- scheduler_config=None,
74
- use_positional_encodings=False,
75
- learn_logvar=False,
76
- logvar_init=0.,
77
- load_ema=True,
78
- ):
79
- super().__init__()
80
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
81
- self.parameterization = parameterization
82
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
83
- self.cond_stage_model = None
84
- self.clip_denoised = clip_denoised
85
- self.log_every_t = log_every_t
86
- self.first_stage_key = first_stage_key
87
- self.image_size = image_size # try conv?
88
- self.channels = channels
89
- self.use_positional_encodings = use_positional_encodings
90
- self.model = DiffusionWrapper(unet_config, conditioning_key)
91
- count_params(self.model, verbose=True)
92
- self.use_ema = use_ema
93
-
94
- self.use_scheduler = scheduler_config is not None
95
- if self.use_scheduler:
96
- self.scheduler_config = scheduler_config
97
-
98
- self.v_posterior = v_posterior
99
- self.original_elbo_weight = original_elbo_weight
100
- self.l_simple_weight = l_simple_weight
101
-
102
- if monitor is not None:
103
- self.monitor = monitor
104
-
105
- if self.use_ema and load_ema:
106
- self.model_ema = LitEma(self.model)
107
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
108
-
109
- if ckpt_path is not None:
110
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
111
-
112
- # If initialing from EMA-only checkpoint, create EMA model after loading.
113
- if self.use_ema and not load_ema:
114
- self.model_ema = LitEma(self.model)
115
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
116
-
117
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
118
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
119
-
120
- self.loss_type = loss_type
121
-
122
- self.learn_logvar = learn_logvar
123
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
124
- if self.learn_logvar:
125
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
126
-
127
-
128
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
129
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
130
- if exists(given_betas):
131
- betas = given_betas
132
- else:
133
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
134
- cosine_s=cosine_s)
135
- alphas = 1. - betas
136
- alphas_cumprod = np.cumprod(alphas, axis=0)
137
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
138
-
139
- timesteps, = betas.shape
140
- self.num_timesteps = int(timesteps)
141
- self.linear_start = linear_start
142
- self.linear_end = linear_end
143
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
144
-
145
- to_torch = partial(torch.tensor, dtype=torch.float32)
146
-
147
- self.register_buffer('betas', to_torch(betas))
148
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
149
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
150
-
151
- # calculations for diffusion q(x_t | x_{t-1}) and others
152
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
153
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
154
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
155
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
156
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
157
-
158
- # calculations for posterior q(x_{t-1} | x_t, x_0)
159
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
160
- 1. - alphas_cumprod) + self.v_posterior * betas
161
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
162
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
163
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
164
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
165
- self.register_buffer('posterior_mean_coef1', to_torch(
166
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
167
- self.register_buffer('posterior_mean_coef2', to_torch(
168
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
169
-
170
- if self.parameterization == "eps":
171
- lvlb_weights = self.betas ** 2 / (
172
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
173
- elif self.parameterization == "x0":
174
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
175
- else:
176
- raise NotImplementedError("mu not supported")
177
- # TODO how to choose this term
178
- lvlb_weights[0] = lvlb_weights[1]
179
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
180
- assert not torch.isnan(self.lvlb_weights).all()
181
-
182
- @contextmanager
183
- def ema_scope(self, context=None):
184
- if self.use_ema:
185
- self.model_ema.store(self.model.parameters())
186
- self.model_ema.copy_to(self.model)
187
- if context is not None:
188
- print(f"{context}: Switched to EMA weights")
189
- try:
190
- yield None
191
- finally:
192
- if self.use_ema:
193
- self.model_ema.restore(self.model.parameters())
194
- if context is not None:
195
- print(f"{context}: Restored training weights")
196
-
197
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
198
- sd = torch.load(path, map_location="cpu")
199
- if "state_dict" in list(sd.keys()):
200
- sd = sd["state_dict"]
201
- keys = list(sd.keys())
202
-
203
- # Our model adds additional channels to the first layer to condition on an input image.
204
- # For the first layer, copy existing channel weights and initialize new channel weights to zero.
205
- input_keys = [
206
- "model.diffusion_model.input_blocks.0.0.weight",
207
- "model_ema.diffusion_modelinput_blocks00weight",
208
- ]
209
-
210
- self_sd = self.state_dict()
211
- for input_key in input_keys:
212
- if input_key not in sd or input_key not in self_sd:
213
- continue
214
-
215
- input_weight = self_sd[input_key]
216
-
217
- if input_weight.size() != sd[input_key].size():
218
- print(f"Manual init: {input_key}")
219
- input_weight.zero_()
220
- input_weight[:, :4, :, :].copy_(sd[input_key])
221
- ignore_keys.append(input_key)
222
-
223
- for k in keys:
224
- for ik in ignore_keys:
225
- if k.startswith(ik):
226
- print("Deleting key {} from state_dict.".format(k))
227
- del sd[k]
228
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
229
- sd, strict=False)
230
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
231
- if len(missing) > 0:
232
- print(f"Missing Keys: {missing}")
233
- if len(unexpected) > 0:
234
- print(f"Unexpected Keys: {unexpected}")
235
-
236
- def q_mean_variance(self, x_start, t):
237
- """
238
- Get the distribution q(x_t | x_0).
239
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
240
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
241
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
242
- """
243
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
244
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
245
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
246
- return mean, variance, log_variance
247
-
248
- def predict_start_from_noise(self, x_t, t, noise):
249
- return (
250
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
251
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
252
- )
253
-
254
- def q_posterior(self, x_start, x_t, t):
255
- posterior_mean = (
256
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
257
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
258
- )
259
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
260
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
261
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
262
-
263
- def p_mean_variance(self, x, t, clip_denoised: bool):
264
- model_out = self.model(x, t)
265
- if self.parameterization == "eps":
266
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
267
- elif self.parameterization == "x0":
268
- x_recon = model_out
269
- if clip_denoised:
270
- x_recon.clamp_(-1., 1.)
271
-
272
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
273
- return model_mean, posterior_variance, posterior_log_variance
274
-
275
- @torch.no_grad()
276
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
277
- b, *_, device = *x.shape, x.device
278
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
279
- noise = noise_like(x.shape, device, repeat_noise)
280
- # no noise when t == 0
281
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
282
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
283
-
284
- @torch.no_grad()
285
- def p_sample_loop(self, shape, return_intermediates=False):
286
- device = self.betas.device
287
- b = shape[0]
288
- img = torch.randn(shape, device=device)
289
- intermediates = [img]
290
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
291
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
292
- clip_denoised=self.clip_denoised)
293
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
294
- intermediates.append(img)
295
- if return_intermediates:
296
- return img, intermediates
297
- return img
298
-
299
- @torch.no_grad()
300
- def sample(self, batch_size=16, return_intermediates=False):
301
- image_size = self.image_size
302
- channels = self.channels
303
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
304
- return_intermediates=return_intermediates)
305
-
306
- def q_sample(self, x_start, t, noise=None):
307
- noise = default(noise, lambda: torch.randn_like(x_start))
308
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
309
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
310
-
311
- def get_loss(self, pred, target, mean=True):
312
- if self.loss_type == 'l1':
313
- loss = (target - pred).abs()
314
- if mean:
315
- loss = loss.mean()
316
- elif self.loss_type == 'l2':
317
- if mean:
318
- loss = torch.nn.functional.mse_loss(target, pred)
319
- else:
320
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
321
- else:
322
- raise NotImplementedError("unknown loss type '{loss_type}'")
323
-
324
- return loss
325
-
326
- def p_losses(self, x_start, t, noise=None):
327
- noise = default(noise, lambda: torch.randn_like(x_start))
328
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
329
- model_out = self.model(x_noisy, t)
330
-
331
- loss_dict = {}
332
- if self.parameterization == "eps":
333
- target = noise
334
- elif self.parameterization == "x0":
335
- target = x_start
336
- else:
337
- raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
338
-
339
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
340
-
341
- log_prefix = 'train' if self.training else 'val'
342
-
343
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
344
- loss_simple = loss.mean() * self.l_simple_weight
345
-
346
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
347
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
348
-
349
- loss = loss_simple + self.original_elbo_weight * loss_vlb
350
-
351
- loss_dict.update({f'{log_prefix}/loss': loss})
352
-
353
- return loss, loss_dict
354
-
355
- def forward(self, x, *args, **kwargs):
356
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
357
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
358
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
359
- return self.p_losses(x, t, *args, **kwargs)
360
-
361
- def get_input(self, batch, k):
362
- return batch[k]
363
-
364
- def shared_step(self, batch):
365
- x = self.get_input(batch, self.first_stage_key)
366
- loss, loss_dict = self(x)
367
- return loss, loss_dict
368
-
369
- def training_step(self, batch, batch_idx):
370
- loss, loss_dict = self.shared_step(batch)
371
-
372
- self.log_dict(loss_dict, prog_bar=True,
373
- logger=True, on_step=True, on_epoch=True)
374
-
375
- self.log("global_step", self.global_step,
376
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
377
-
378
- if self.use_scheduler:
379
- lr = self.optimizers().param_groups[0]['lr']
380
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
381
-
382
- return loss
383
-
384
- @torch.no_grad()
385
- def validation_step(self, batch, batch_idx):
386
- _, loss_dict_no_ema = self.shared_step(batch)
387
- with self.ema_scope():
388
- _, loss_dict_ema = self.shared_step(batch)
389
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
390
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
391
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
392
-
393
- def on_train_batch_end(self, *args, **kwargs):
394
- if self.use_ema:
395
- self.model_ema(self.model)
396
-
397
- def _get_rows_from_list(self, samples):
398
- n_imgs_per_row = len(samples)
399
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
400
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
401
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
402
- return denoise_grid
403
-
404
- @torch.no_grad()
405
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
406
- log = dict()
407
- x = self.get_input(batch, self.first_stage_key)
408
- N = min(x.shape[0], N)
409
- n_row = min(x.shape[0], n_row)
410
- x = x.to(self.device)[:N]
411
- log["inputs"] = x
412
-
413
- # get diffusion row
414
- diffusion_row = list()
415
- x_start = x[:n_row]
416
-
417
- for t in range(self.num_timesteps):
418
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
419
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
420
- t = t.to(self.device).long()
421
- noise = torch.randn_like(x_start)
422
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
423
- diffusion_row.append(x_noisy)
424
-
425
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
426
-
427
- if sample:
428
- # get denoise row
429
- with self.ema_scope("Plotting"):
430
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
431
-
432
- log["samples"] = samples
433
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
434
-
435
- if return_keys:
436
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
437
- return log
438
- else:
439
- return {key: log[key] for key in return_keys}
440
- return log
441
-
442
- def configure_optimizers(self):
443
- lr = self.learning_rate
444
- params = list(self.model.parameters())
445
- if self.learn_logvar:
446
- params = params + [self.logvar]
447
- opt = torch.optim.AdamW(params, lr=lr)
448
- return opt
449
-
450
-
451
- class LatentDiffusion(DDPM):
452
- """main class"""
453
- def __init__(self,
454
- first_stage_config,
455
- cond_stage_config,
456
- num_timesteps_cond=None,
457
- cond_stage_key="image",
458
- cond_stage_trainable=False,
459
- concat_mode=True,
460
- cond_stage_forward=None,
461
- conditioning_key=None,
462
- scale_factor=1.0,
463
- scale_by_std=False,
464
- load_ema=True,
465
- *args, **kwargs):
466
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
467
- self.scale_by_std = scale_by_std
468
- assert self.num_timesteps_cond <= kwargs['timesteps']
469
- # for backwards compatibility after implementation of DiffusionWrapper
470
- if conditioning_key is None:
471
- conditioning_key = 'concat' if concat_mode else 'crossattn'
472
- if cond_stage_config == '__is_unconditional__':
473
- conditioning_key = None
474
- ckpt_path = kwargs.pop("ckpt_path", None)
475
- ignore_keys = kwargs.pop("ignore_keys", [])
476
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
477
- self.concat_mode = concat_mode
478
- self.cond_stage_trainable = cond_stage_trainable
479
- self.cond_stage_key = cond_stage_key
480
- try:
481
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
482
- except:
483
- self.num_downs = 0
484
- if not scale_by_std:
485
- self.scale_factor = scale_factor
486
- else:
487
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
488
- self.instantiate_first_stage(first_stage_config)
489
- self.instantiate_cond_stage(cond_stage_config)
490
- self.cond_stage_forward = cond_stage_forward
491
- self.clip_denoised = False
492
- self.bbox_tokenizer = None
493
-
494
- self.restarted_from_ckpt = False
495
- if ckpt_path is not None:
496
- self.init_from_ckpt(ckpt_path, ignore_keys)
497
- self.restarted_from_ckpt = True
498
-
499
- if self.use_ema and not load_ema:
500
- self.model_ema = LitEma(self.model)
501
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
502
-
503
- def make_cond_schedule(self, ):
504
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
505
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
506
- self.cond_ids[:self.num_timesteps_cond] = ids
507
-
508
- @rank_zero_only
509
- @torch.no_grad()
510
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
511
- # only for very first batch
512
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
513
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
514
- # set rescale weight to 1./std of encodings
515
- print("### USING STD-RESCALING ###")
516
- x = super().get_input(batch, self.first_stage_key)
517
- x = x.to(self.device)
518
- encoder_posterior = self.encode_first_stage(x)
519
- z = self.get_first_stage_encoding(encoder_posterior).detach()
520
- del self.scale_factor
521
- self.register_buffer('scale_factor', 1. / z.flatten().std())
522
- print(f"setting self.scale_factor to {self.scale_factor}")
523
- print("### USING STD-RESCALING ###")
524
-
525
- def register_schedule(self,
526
- given_betas=None, beta_schedule="linear", timesteps=1000,
527
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
528
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
529
-
530
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
531
- if self.shorten_cond_schedule:
532
- self.make_cond_schedule()
533
-
534
- def instantiate_first_stage(self, config):
535
- model = instantiate_from_config(config)
536
- self.first_stage_model = model.eval()
537
- self.first_stage_model.train = disabled_train
538
- for param in self.first_stage_model.parameters():
539
- param.requires_grad = False
540
-
541
- def instantiate_cond_stage(self, config):
542
- if not self.cond_stage_trainable:
543
- if config == "__is_first_stage__":
544
- print("Using first stage also as cond stage.")
545
- self.cond_stage_model = self.first_stage_model
546
- elif config == "__is_unconditional__":
547
- print(f"Training {self.__class__.__name__} as an unconditional model.")
548
- self.cond_stage_model = None
549
- # self.be_unconditional = True
550
- else:
551
- model = instantiate_from_config(config)
552
- self.cond_stage_model = model.eval()
553
- self.cond_stage_model.train = disabled_train
554
- for param in self.cond_stage_model.parameters():
555
- param.requires_grad = False
556
- else:
557
- assert config != '__is_first_stage__'
558
- assert config != '__is_unconditional__'
559
- model = instantiate_from_config(config)
560
- self.cond_stage_model = model
561
-
562
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
563
- denoise_row = []
564
- for zd in tqdm(samples, desc=desc):
565
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
566
- force_not_quantize=force_no_decoder_quantization))
567
- n_imgs_per_row = len(denoise_row)
568
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
569
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
570
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
571
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
572
- return denoise_grid
573
-
574
- def get_first_stage_encoding(self, encoder_posterior):
575
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
576
- z = encoder_posterior.sample()
577
- elif isinstance(encoder_posterior, torch.Tensor):
578
- z = encoder_posterior
579
- else:
580
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
581
- return self.scale_factor * z
582
-
583
- def get_learned_conditioning(self, c):
584
- if self.cond_stage_forward is None:
585
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
586
- c = self.cond_stage_model.encode(c)
587
- if isinstance(c, DiagonalGaussianDistribution):
588
- c = c.mode()
589
- else:
590
- c = self.cond_stage_model(c)
591
- else:
592
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
593
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
594
- return c
595
-
596
- def meshgrid(self, h, w):
597
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
598
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
599
-
600
- arr = torch.cat([y, x], dim=-1)
601
- return arr
602
-
603
- def delta_border(self, h, w):
604
- """
605
- :param h: height
606
- :param w: width
607
- :return: normalized distance to image border,
608
- wtith min distance = 0 at border and max dist = 0.5 at image center
609
- """
610
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
611
- arr = self.meshgrid(h, w) / lower_right_corner
612
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
613
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
614
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
615
- return edge_dist
616
-
617
- def get_weighting(self, h, w, Ly, Lx, device):
618
- weighting = self.delta_border(h, w)
619
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
620
- self.split_input_params["clip_max_weight"], )
621
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
622
-
623
- if self.split_input_params["tie_braker"]:
624
- L_weighting = self.delta_border(Ly, Lx)
625
- L_weighting = torch.clip(L_weighting,
626
- self.split_input_params["clip_min_tie_weight"],
627
- self.split_input_params["clip_max_tie_weight"])
628
-
629
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
630
- weighting = weighting * L_weighting
631
- return weighting
632
-
633
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
634
- """
635
- :param x: img of size (bs, c, h, w)
636
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
637
- """
638
- bs, nc, h, w = x.shape
639
-
640
- # number of crops in image
641
- Ly = (h - kernel_size[0]) // stride[0] + 1
642
- Lx = (w - kernel_size[1]) // stride[1] + 1
643
-
644
- if uf == 1 and df == 1:
645
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
646
- unfold = torch.nn.Unfold(**fold_params)
647
-
648
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
649
-
650
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
651
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
652
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
653
-
654
- elif uf > 1 and df == 1:
655
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
656
- unfold = torch.nn.Unfold(**fold_params)
657
-
658
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
659
- dilation=1, padding=0,
660
- stride=(stride[0] * uf, stride[1] * uf))
661
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
662
-
663
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
664
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
665
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
666
-
667
- elif df > 1 and uf == 1:
668
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
669
- unfold = torch.nn.Unfold(**fold_params)
670
-
671
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
672
- dilation=1, padding=0,
673
- stride=(stride[0] // df, stride[1] // df))
674
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
675
-
676
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
677
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
678
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
679
-
680
- else:
681
- raise NotImplementedError
682
-
683
- return fold, unfold, normalization, weighting
684
-
685
- @torch.no_grad()
686
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
687
- cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
688
- x = super().get_input(batch, k)
689
- if bs is not None:
690
- x = x[:bs]
691
- x = x.to(self.device)
692
- encoder_posterior = self.encode_first_stage(x)
693
- z = self.get_first_stage_encoding(encoder_posterior).detach()
694
- cond_key = cond_key or self.cond_stage_key
695
- xc = super().get_input(batch, cond_key)
696
- if bs is not None:
697
- xc["c_crossattn"] = xc["c_crossattn"][:bs]
698
- xc["c_concat"] = xc["c_concat"][:bs]
699
- cond = {}
700
-
701
- # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
702
- random = torch.rand(x.size(0), device=x.device)
703
- prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
704
- input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
705
-
706
- null_prompt = self.get_learned_conditioning([""])
707
- cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
708
- cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
709
-
710
- out = [z, cond]
711
- if return_first_stage_outputs:
712
- xrec = self.decode_first_stage(z)
713
- out.extend([x, xrec])
714
- if return_original_cond:
715
- out.append(xc)
716
- return out
717
-
718
- @torch.no_grad()
719
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
720
- if predict_cids:
721
- if z.dim() == 4:
722
- z = torch.argmax(z.exp(), dim=1).long()
723
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
724
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
725
-
726
- z = 1. / self.scale_factor * z
727
-
728
- if hasattr(self, "split_input_params"):
729
- if self.split_input_params["patch_distributed_vq"]:
730
- ks = self.split_input_params["ks"] # eg. (128, 128)
731
- stride = self.split_input_params["stride"] # eg. (64, 64)
732
- uf = self.split_input_params["vqf"]
733
- bs, nc, h, w = z.shape
734
- if ks[0] > h or ks[1] > w:
735
- ks = (min(ks[0], h), min(ks[1], w))
736
- print("reducing Kernel")
737
-
738
- if stride[0] > h or stride[1] > w:
739
- stride = (min(stride[0], h), min(stride[1], w))
740
- print("reducing stride")
741
-
742
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
743
-
744
- z = unfold(z) # (bn, nc * prod(**ks), L)
745
- # 1. Reshape to img shape
746
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
747
-
748
- # 2. apply model loop over last dim
749
- if isinstance(self.first_stage_model, VQModelInterface):
750
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
751
- force_not_quantize=predict_cids or force_not_quantize)
752
- for i in range(z.shape[-1])]
753
- else:
754
-
755
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
756
- for i in range(z.shape[-1])]
757
-
758
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
759
- o = o * weighting
760
- # Reverse 1. reshape to img shape
761
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
762
- # stitch crops together
763
- decoded = fold(o)
764
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
765
- return decoded
766
- else:
767
- if isinstance(self.first_stage_model, VQModelInterface):
768
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
769
- else:
770
- return self.first_stage_model.decode(z)
771
-
772
- else:
773
- if isinstance(self.first_stage_model, VQModelInterface):
774
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
775
- else:
776
- return self.first_stage_model.decode(z)
777
-
778
- # same as above but without decorator
779
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
780
- if predict_cids:
781
- if z.dim() == 4:
782
- z = torch.argmax(z.exp(), dim=1).long()
783
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
784
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
785
-
786
- z = 1. / self.scale_factor * z
787
-
788
- if hasattr(self, "split_input_params"):
789
- if self.split_input_params["patch_distributed_vq"]:
790
- ks = self.split_input_params["ks"] # eg. (128, 128)
791
- stride = self.split_input_params["stride"] # eg. (64, 64)
792
- uf = self.split_input_params["vqf"]
793
- bs, nc, h, w = z.shape
794
- if ks[0] > h or ks[1] > w:
795
- ks = (min(ks[0], h), min(ks[1], w))
796
- print("reducing Kernel")
797
-
798
- if stride[0] > h or stride[1] > w:
799
- stride = (min(stride[0], h), min(stride[1], w))
800
- print("reducing stride")
801
-
802
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
803
-
804
- z = unfold(z) # (bn, nc * prod(**ks), L)
805
- # 1. Reshape to img shape
806
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
807
-
808
- # 2. apply model loop over last dim
809
- if isinstance(self.first_stage_model, VQModelInterface):
810
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
811
- force_not_quantize=predict_cids or force_not_quantize)
812
- for i in range(z.shape[-1])]
813
- else:
814
-
815
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
816
- for i in range(z.shape[-1])]
817
-
818
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
819
- o = o * weighting
820
- # Reverse 1. reshape to img shape
821
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
822
- # stitch crops together
823
- decoded = fold(o)
824
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
825
- return decoded
826
- else:
827
- if isinstance(self.first_stage_model, VQModelInterface):
828
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
829
- else:
830
- return self.first_stage_model.decode(z)
831
-
832
- else:
833
- if isinstance(self.first_stage_model, VQModelInterface):
834
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
835
- else:
836
- return self.first_stage_model.decode(z)
837
-
838
- @torch.no_grad()
839
- def encode_first_stage(self, x):
840
- if hasattr(self, "split_input_params"):
841
- if self.split_input_params["patch_distributed_vq"]:
842
- ks = self.split_input_params["ks"] # eg. (128, 128)
843
- stride = self.split_input_params["stride"] # eg. (64, 64)
844
- df = self.split_input_params["vqf"]
845
- self.split_input_params['original_image_size'] = x.shape[-2:]
846
- bs, nc, h, w = x.shape
847
- if ks[0] > h or ks[1] > w:
848
- ks = (min(ks[0], h), min(ks[1], w))
849
- print("reducing Kernel")
850
-
851
- if stride[0] > h or stride[1] > w:
852
- stride = (min(stride[0], h), min(stride[1], w))
853
- print("reducing stride")
854
-
855
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
856
- z = unfold(x) # (bn, nc * prod(**ks), L)
857
- # Reshape to img shape
858
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
859
-
860
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
861
- for i in range(z.shape[-1])]
862
-
863
- o = torch.stack(output_list, axis=-1)
864
- o = o * weighting
865
-
866
- # Reverse reshape to img shape
867
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
868
- # stitch crops together
869
- decoded = fold(o)
870
- decoded = decoded / normalization
871
- return decoded
872
-
873
- else:
874
- return self.first_stage_model.encode(x)
875
- else:
876
- return self.first_stage_model.encode(x)
877
-
878
- def shared_step(self, batch, **kwargs):
879
- x, c = self.get_input(batch, self.first_stage_key)
880
- loss = self(x, c)
881
- return loss
882
-
883
- def forward(self, x, c, *args, **kwargs):
884
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
885
- if self.model.conditioning_key is not None:
886
- assert c is not None
887
- if self.cond_stage_trainable:
888
- c = self.get_learned_conditioning(c)
889
- if self.shorten_cond_schedule: # TODO: drop this option
890
- tc = self.cond_ids[t].to(self.device)
891
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
892
- return self.p_losses(x, c, t, *args, **kwargs)
893
-
894
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
895
- def rescale_bbox(bbox):
896
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
897
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
898
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
899
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
900
- return x0, y0, w, h
901
-
902
- return [rescale_bbox(b) for b in bboxes]
903
-
904
- def apply_model(self, x_noisy, t, cond, return_ids=False):
905
-
906
- if isinstance(cond, dict):
907
- # hybrid case, cond is exptected to be a dict
908
- pass
909
- else:
910
- if not isinstance(cond, list):
911
- cond = [cond]
912
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
913
- cond = {key: cond}
914
-
915
- if hasattr(self, "split_input_params"):
916
- assert len(cond) == 1 # todo can only deal with one conditioning atm
917
- assert not return_ids
918
- ks = self.split_input_params["ks"] # eg. (128, 128)
919
- stride = self.split_input_params["stride"] # eg. (64, 64)
920
-
921
- h, w = x_noisy.shape[-2:]
922
-
923
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
924
-
925
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
926
- # Reshape to img shape
927
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
928
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
929
-
930
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
931
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
932
- c_key = next(iter(cond.keys())) # get key
933
- c = next(iter(cond.values())) # get value
934
- assert (len(c) == 1) # todo extend to list with more than one elem
935
- c = c[0] # get element
936
-
937
- c = unfold(c)
938
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
939
-
940
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
941
-
942
- elif self.cond_stage_key == 'coordinates_bbox':
943
- assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
944
-
945
- # assuming padding of unfold is always 0 and its dilation is always 1
946
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
947
- full_img_h, full_img_w = self.split_input_params['original_image_size']
948
- # as we are operating on latents, we need the factor from the original image size to the
949
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
950
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
951
- rescale_latent = 2 ** (num_downs)
952
-
953
- # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
954
- # need to rescale the tl patch coordinates to be in between (0,1)
955
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
956
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
957
- for patch_nr in range(z.shape[-1])]
958
-
959
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
960
- patch_limits = [(x_tl, y_tl,
961
- rescale_latent * ks[0] / full_img_w,
962
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
963
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
964
-
965
- # tokenize crop coordinates for the bounding boxes of the respective patches
966
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
967
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
968
- print(patch_limits_tknzd[0].shape)
969
- # cut tknzd crop position from conditioning
970
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
971
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
972
- print(cut_cond.shape)
973
-
974
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
975
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
976
- print(adapted_cond.shape)
977
- adapted_cond = self.get_learned_conditioning(adapted_cond)
978
- print(adapted_cond.shape)
979
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
980
- print(adapted_cond.shape)
981
-
982
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
983
-
984
- else:
985
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
986
-
987
- # apply model by loop over crops
988
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
989
- assert not isinstance(output_list[0],
990
- tuple) # todo cant deal with multiple model outputs check this never happens
991
-
992
- o = torch.stack(output_list, axis=-1)
993
- o = o * weighting
994
- # Reverse reshape to img shape
995
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
996
- # stitch crops together
997
- x_recon = fold(o) / normalization
998
-
999
- else:
1000
- x_recon = self.model(x_noisy, t, **cond)
1001
-
1002
- if isinstance(x_recon, tuple) and not return_ids:
1003
- return x_recon[0]
1004
- else:
1005
- return x_recon
1006
-
1007
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1008
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1009
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1010
-
1011
- def _prior_bpd(self, x_start):
1012
- """
1013
- Get the prior KL term for the variational lower-bound, measured in
1014
- bits-per-dim.
1015
- This term can't be optimized, as it only depends on the encoder.
1016
- :param x_start: the [N x C x ...] tensor of inputs.
1017
- :return: a batch of [N] KL values (in bits), one per batch element.
1018
- """
1019
- batch_size = x_start.shape[0]
1020
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1021
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1022
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1023
- return mean_flat(kl_prior) / np.log(2.0)
1024
-
1025
- def p_losses(self, x_start, cond, t, noise=None):
1026
- noise = default(noise, lambda: torch.randn_like(x_start))
1027
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1028
- model_output = self.apply_model(x_noisy, t, cond)
1029
-
1030
- loss_dict = {}
1031
- prefix = 'train' if self.training else 'val'
1032
-
1033
- if self.parameterization == "x0":
1034
- target = x_start
1035
- elif self.parameterization == "eps":
1036
- target = noise
1037
- else:
1038
- raise NotImplementedError()
1039
-
1040
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1041
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1042
-
1043
- # logvar_t = self.logvar[t].to(self.device)
1044
- # loss = loss_simple / torch.exp(logvar_t) + logvar_t
1045
- self.logvar = self.logvar.to(self.device)
1046
- logvar_t = self.logvar[t]
1047
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
1048
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1049
- if self.learn_logvar:
1050
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1051
- loss_dict.update({'logvar': self.logvar.data.mean()})
1052
-
1053
- loss = self.l_simple_weight * loss.mean()
1054
-
1055
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1056
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1057
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1058
- loss += (self.original_elbo_weight * loss_vlb)
1059
- loss_dict.update({f'{prefix}/loss': loss})
1060
-
1061
- return loss, loss_dict
1062
-
1063
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1064
- return_x0=False, score_corrector=None, corrector_kwargs=None):
1065
- t_in = t
1066
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1067
-
1068
- if score_corrector is not None:
1069
- assert self.parameterization == "eps"
1070
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1071
-
1072
- if return_codebook_ids:
1073
- model_out, logits = model_out
1074
-
1075
- if self.parameterization == "eps":
1076
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1077
- elif self.parameterization == "x0":
1078
- x_recon = model_out
1079
- else:
1080
- raise NotImplementedError()
1081
-
1082
- if clip_denoised:
1083
- x_recon.clamp_(-1., 1.)
1084
- if quantize_denoised:
1085
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1086
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1087
- if return_codebook_ids:
1088
- return model_mean, posterior_variance, posterior_log_variance, logits
1089
- elif return_x0:
1090
- return model_mean, posterior_variance, posterior_log_variance, x_recon
1091
- else:
1092
- return model_mean, posterior_variance, posterior_log_variance
1093
-
1094
- @torch.no_grad()
1095
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1096
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1097
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1098
- b, *_, device = *x.shape, x.device
1099
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1100
- return_codebook_ids=return_codebook_ids,
1101
- quantize_denoised=quantize_denoised,
1102
- return_x0=return_x0,
1103
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1104
- if return_codebook_ids:
1105
- raise DeprecationWarning("Support dropped.")
1106
- model_mean, _, model_log_variance, logits = outputs
1107
- elif return_x0:
1108
- model_mean, _, model_log_variance, x0 = outputs
1109
- else:
1110
- model_mean, _, model_log_variance = outputs
1111
-
1112
- noise = noise_like(x.shape, device, repeat_noise) * temperature
1113
- if noise_dropout > 0.:
1114
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1115
- # no noise when t == 0
1116
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1117
-
1118
- if return_codebook_ids:
1119
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1120
- if return_x0:
1121
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1122
- else:
1123
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1124
-
1125
- @torch.no_grad()
1126
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1127
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1128
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1129
- log_every_t=None):
1130
- if not log_every_t:
1131
- log_every_t = self.log_every_t
1132
- timesteps = self.num_timesteps
1133
- if batch_size is not None:
1134
- b = batch_size if batch_size is not None else shape[0]
1135
- shape = [batch_size] + list(shape)
1136
- else:
1137
- b = batch_size = shape[0]
1138
- if x_T is None:
1139
- img = torch.randn(shape, device=self.device)
1140
- else:
1141
- img = x_T
1142
- intermediates = []
1143
- if cond is not None:
1144
- if isinstance(cond, dict):
1145
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1146
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1147
- else:
1148
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1149
-
1150
- if start_T is not None:
1151
- timesteps = min(timesteps, start_T)
1152
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1153
- total=timesteps) if verbose else reversed(
1154
- range(0, timesteps))
1155
- if type(temperature) == float:
1156
- temperature = [temperature] * timesteps
1157
-
1158
- for i in iterator:
1159
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1160
- if self.shorten_cond_schedule:
1161
- assert self.model.conditioning_key != 'hybrid'
1162
- tc = self.cond_ids[ts].to(cond.device)
1163
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1164
-
1165
- img, x0_partial = self.p_sample(img, cond, ts,
1166
- clip_denoised=self.clip_denoised,
1167
- quantize_denoised=quantize_denoised, return_x0=True,
1168
- temperature=temperature[i], noise_dropout=noise_dropout,
1169
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1170
- if mask is not None:
1171
- assert x0 is not None
1172
- img_orig = self.q_sample(x0, ts)
1173
- img = img_orig * mask + (1. - mask) * img
1174
-
1175
- if i % log_every_t == 0 or i == timesteps - 1:
1176
- intermediates.append(x0_partial)
1177
- if callback: callback(i)
1178
- if img_callback: img_callback(img, i)
1179
- return img, intermediates
1180
-
1181
- @torch.no_grad()
1182
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1183
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1184
- mask=None, x0=None, img_callback=None, start_T=None,
1185
- log_every_t=None):
1186
-
1187
- if not log_every_t:
1188
- log_every_t = self.log_every_t
1189
- device = self.betas.device
1190
- b = shape[0]
1191
- if x_T is None:
1192
- img = torch.randn(shape, device=device)
1193
- else:
1194
- img = x_T
1195
-
1196
- intermediates = [img]
1197
- if timesteps is None:
1198
- timesteps = self.num_timesteps
1199
-
1200
- if start_T is not None:
1201
- timesteps = min(timesteps, start_T)
1202
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1203
- range(0, timesteps))
1204
-
1205
- if mask is not None:
1206
- assert x0 is not None
1207
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1208
-
1209
- for i in iterator:
1210
- ts = torch.full((b,), i, device=device, dtype=torch.long)
1211
- if self.shorten_cond_schedule:
1212
- assert self.model.conditioning_key != 'hybrid'
1213
- tc = self.cond_ids[ts].to(cond.device)
1214
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1215
-
1216
- img = self.p_sample(img, cond, ts,
1217
- clip_denoised=self.clip_denoised,
1218
- quantize_denoised=quantize_denoised)
1219
- if mask is not None:
1220
- img_orig = self.q_sample(x0, ts)
1221
- img = img_orig * mask + (1. - mask) * img
1222
-
1223
- if i % log_every_t == 0 or i == timesteps - 1:
1224
- intermediates.append(img)
1225
- if callback: callback(i)
1226
- if img_callback: img_callback(img, i)
1227
-
1228
- if return_intermediates:
1229
- return img, intermediates
1230
- return img
1231
-
1232
- @torch.no_grad()
1233
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1234
- verbose=True, timesteps=None, quantize_denoised=False,
1235
- mask=None, x0=None, shape=None,**kwargs):
1236
- if shape is None:
1237
- shape = (batch_size, self.channels, self.image_size, self.image_size)
1238
- if cond is not None:
1239
- if isinstance(cond, dict):
1240
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1241
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1242
- else:
1243
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1244
- return self.p_sample_loop(cond,
1245
- shape,
1246
- return_intermediates=return_intermediates, x_T=x_T,
1247
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1248
- mask=mask, x0=x0)
1249
-
1250
- @torch.no_grad()
1251
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1252
-
1253
- if ddim:
1254
- ddim_sampler = DDIMSampler(self)
1255
- shape = (self.channels, self.image_size, self.image_size)
1256
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1257
- shape,cond,verbose=False,**kwargs)
1258
-
1259
- else:
1260
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1261
- return_intermediates=True,**kwargs)
1262
-
1263
- return samples, intermediates
1264
-
1265
-
1266
- @torch.no_grad()
1267
- def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1268
- quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1269
- plot_diffusion_rows=False, **kwargs):
1270
-
1271
- use_ddim = False
1272
-
1273
- log = dict()
1274
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1275
- return_first_stage_outputs=True,
1276
- force_c_encode=True,
1277
- return_original_cond=True,
1278
- bs=N, uncond=0)
1279
- N = min(x.shape[0], N)
1280
- n_row = min(x.shape[0], n_row)
1281
- log["inputs"] = x
1282
- log["reals"] = xc["c_concat"]
1283
- log["reconstruction"] = xrec
1284
- if self.model.conditioning_key is not None:
1285
- if hasattr(self.cond_stage_model, "decode"):
1286
- xc = self.cond_stage_model.decode(c)
1287
- log["conditioning"] = xc
1288
- elif self.cond_stage_key in ["caption"]:
1289
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1290
- log["conditioning"] = xc
1291
- elif self.cond_stage_key == 'class_label':
1292
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1293
- log['conditioning'] = xc
1294
- elif isimage(xc):
1295
- log["conditioning"] = xc
1296
- if ismap(xc):
1297
- log["original_conditioning"] = self.to_rgb(xc)
1298
-
1299
- if plot_diffusion_rows:
1300
- # get diffusion row
1301
- diffusion_row = list()
1302
- z_start = z[:n_row]
1303
- for t in range(self.num_timesteps):
1304
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1305
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1306
- t = t.to(self.device).long()
1307
- noise = torch.randn_like(z_start)
1308
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1309
- diffusion_row.append(self.decode_first_stage(z_noisy))
1310
-
1311
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1312
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1313
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1314
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1315
- log["diffusion_row"] = diffusion_grid
1316
-
1317
- if sample:
1318
- # get denoise row
1319
- with self.ema_scope("Plotting"):
1320
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1321
- ddim_steps=ddim_steps,eta=ddim_eta)
1322
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1323
- x_samples = self.decode_first_stage(samples)
1324
- log["samples"] = x_samples
1325
- if plot_denoise_rows:
1326
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1327
- log["denoise_row"] = denoise_grid
1328
-
1329
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1330
- self.first_stage_model, IdentityFirstStage):
1331
- # also display when quantizing x0 while sampling
1332
- with self.ema_scope("Plotting Quantized Denoised"):
1333
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1334
- ddim_steps=ddim_steps,eta=ddim_eta,
1335
- quantize_denoised=True)
1336
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1337
- # quantize_denoised=True)
1338
- x_samples = self.decode_first_stage(samples.to(self.device))
1339
- log["samples_x0_quantized"] = x_samples
1340
-
1341
- if inpaint:
1342
- # make a simple center square
1343
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
1344
- mask = torch.ones(N, h, w).to(self.device)
1345
- # zeros will be filled in
1346
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1347
- mask = mask[:, None, ...]
1348
- with self.ema_scope("Plotting Inpaint"):
1349
-
1350
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1351
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1352
- x_samples = self.decode_first_stage(samples.to(self.device))
1353
- log["samples_inpainting"] = x_samples
1354
- log["mask"] = mask
1355
-
1356
- # outpaint
1357
- with self.ema_scope("Plotting Outpaint"):
1358
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1359
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1360
- x_samples = self.decode_first_stage(samples.to(self.device))
1361
- log["samples_outpainting"] = x_samples
1362
-
1363
- if plot_progressive_rows:
1364
- with self.ema_scope("Plotting Progressives"):
1365
- img, progressives = self.progressive_denoising(c,
1366
- shape=(self.channels, self.image_size, self.image_size),
1367
- batch_size=N)
1368
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1369
- log["progressive_row"] = prog_row
1370
-
1371
- if return_keys:
1372
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1373
- return log
1374
- else:
1375
- return {key: log[key] for key in return_keys}
1376
- return log
1377
-
1378
- def configure_optimizers(self):
1379
- lr = self.learning_rate
1380
- params = list(self.model.parameters())
1381
- if self.cond_stage_trainable:
1382
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1383
- params = params + list(self.cond_stage_model.parameters())
1384
- if self.learn_logvar:
1385
- print('Diffusion model optimizing logvar')
1386
- params.append(self.logvar)
1387
- opt = torch.optim.AdamW(params, lr=lr)
1388
- if self.use_scheduler:
1389
- assert 'target' in self.scheduler_config
1390
- scheduler = instantiate_from_config(self.scheduler_config)
1391
-
1392
- print("Setting up LambdaLR scheduler...")
1393
- scheduler = [
1394
- {
1395
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1396
- 'interval': 'step',
1397
- 'frequency': 1
1398
- }]
1399
- return [opt], scheduler
1400
- return opt
1401
-
1402
- @torch.no_grad()
1403
- def to_rgb(self, x):
1404
- x = x.float()
1405
- if not hasattr(self, "colorize"):
1406
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1407
- x = nn.functional.conv2d(x, weight=self.colorize)
1408
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1409
- return x
1410
-
1411
-
1412
- class DiffusionWrapper(pl.LightningModule):
1413
- def __init__(self, diff_model_config, conditioning_key):
1414
- super().__init__()
1415
- self.diffusion_model = instantiate_from_config(diff_model_config)
1416
- self.conditioning_key = conditioning_key
1417
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1418
-
1419
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1420
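- # Dispatch on the conditioning mode: 'concat' stacks the conditioning onto the UNet input channels, 'crossattn' feeds it as cross-attention context, 'hybrid' does both, and 'adm' passes it as a class-style embedding y.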
- if self.conditioning_key is None:
1421
- out = self.diffusion_model(x, t)
1422
- elif self.conditioning_key == 'concat':
1423
- xc = torch.cat([x] + c_concat, dim=1)
1424
- out = self.diffusion_model(xc, t)
1425
- elif self.conditioning_key == 'crossattn':
1426
- cc = torch.cat(c_crossattn, 1)
1427
- out = self.diffusion_model(x, t, context=cc)
1428
- elif self.conditioning_key == 'hybrid':
1429
- xc = torch.cat([x] + c_concat, dim=1)
1430
- cc = torch.cat(c_crossattn, 1)
1431
- out = self.diffusion_model(xc, t, context=cc)
1432
- elif self.conditioning_key == 'adm':
1433
- cc = c_crossattn[0]
1434
- out = self.diffusion_model(x, t, y=cc)
1435
- else:
1436
- raise NotImplementedError()
1437
-
1438
- return out
1439
-
1440
-
1441
- class Layout2ImgDiffusion(LatentDiffusion):
1442
- # TODO: move all layout-specific hacks to this class
1443
- def __init__(self, cond_stage_key, *args, **kwargs):
1444
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1445
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1446
-
1447
- def log_images(self, batch, N=8, *args, **kwargs):
1448
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1449
-
1450
- key = 'train' if self.training else 'validation'
1451
- dset = self.trainer.datamodule.datasets[key]
1452
- mapper = dset.conditional_builders[self.cond_stage_key]
1453
-
1454
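- # Render the tokenized bounding boxes of the logged batch back into images so the layout conditioning can be inspected next to the samples.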
- bbox_imgs = []
1455
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1456
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
1457
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1458
- bbox_imgs.append(bboximg)
1459
-
1460
- cond_img = torch.stack(bbox_imgs, dim=0)
1461
- logs['bbox_image'] = cond_img
1462
- return logs
stable_diffusion/ldm/models/diffusion/ddpm_pam.py DELETED
@@ -1,1527 +0,0 @@
1
- """
2
- wild mixture of
3
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
- https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
- https://github.com/CompVis/taming-transformers
6
- -- merci
7
- """
8
-
9
- # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
10
- # See more details in LICENSE.
11
-
12
- import torch
13
- import torch.nn as nn
14
- import numpy as np
15
- import pytorch_lightning as pl
16
- from torch.optim.lr_scheduler import LambdaLR
17
- from einops import rearrange, repeat
18
- from contextlib import contextmanager
19
- from functools import partial
20
- from tqdm import tqdm
21
- from torchvision.utils import make_grid
22
- from pytorch_lightning.utilities.distributed import rank_zero_only
23
-
24
- from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
25
- from ldm.modules.ema import LitEma
26
- from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
27
- from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
28
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
29
- from ldm.models.diffusion.ddim import DDIMSampler
30
-
31
-
32
- __conditioning_keys__ = {'concat': 'c_concat',
33
- 'crossattn': 'c_crossattn',
34
- 'adm': 'y'}
35
-
36
-
37
- def disabled_train(self, mode=True):
38
- """Overwrite model.train with this function to make sure train/eval mode
39
- does not change anymore."""
40
- return self
41
-
42
-
43
- def uniform_on_device(r1, r2, shape, device):
44
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
45
-
46
-
47
- class DDPM(pl.LightningModule):
48
- # classic DDPM with Gaussian diffusion, in image space
49
- def __init__(self,
50
- unet_config,
51
- timesteps=1000,
52
- beta_schedule="linear",
53
- loss_type="l2",
54
- ckpt_path=None,
55
- ignore_keys=[],
56
- load_only_unet=False,
57
- monitor="val/loss",
58
- use_ema=True,
59
- first_stage_key="image",
60
- image_size=256,
61
- channels=3,
62
- log_every_t=100,
63
- clip_denoised=True,
64
- linear_start=1e-4,
65
- linear_end=2e-2,
66
- cosine_s=8e-3,
67
- given_betas=None,
68
- original_elbo_weight=0.,
69
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
70
- l_simple_weight=1.,
71
- conditioning_key=None,
72
- parameterization="eps", # all assuming fixed variance schedules
73
- scheduler_config=None,
74
- use_positional_encodings=False,
75
- learn_logvar=False,
76
- logvar_init=0.,
77
- load_ema=True,
78
- ):
79
- super().__init__()
80
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
81
- self.parameterization = parameterization
82
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
83
- self.cond_stage_model = None
84
- self.clip_denoised = clip_denoised
85
- self.log_every_t = log_every_t
86
- self.first_stage_key = first_stage_key
87
- self.image_size = image_size # try conv?
88
- self.channels = channels
89
- self.use_positional_encodings = use_positional_encodings
90
- self.model = DiffusionWrapper(unet_config, conditioning_key)
91
- count_params(self.model, verbose=True)
92
- self.use_ema = use_ema
93
-
94
- self.use_scheduler = scheduler_config is not None
95
- if self.use_scheduler:
96
- self.scheduler_config = scheduler_config
97
-
98
- self.v_posterior = v_posterior
99
- self.original_elbo_weight = original_elbo_weight
100
- self.l_simple_weight = l_simple_weight
101
-
102
- if monitor is not None:
103
- self.monitor = monitor
104
-
105
- if self.use_ema and load_ema:
106
- self.model_ema = LitEma(self.model)
107
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
108
-
109
- if ckpt_path is not None:
110
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
111
-
112
- # If initializing from an EMA-only checkpoint, create the EMA model after loading.
113
- if self.use_ema and not load_ema:
114
- self.model_ema = LitEma(self.model)
115
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
116
-
117
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
118
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
119
-
120
- self.loss_type = loss_type
121
-
122
- self.learn_logvar = learn_logvar
123
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
124
- if self.learn_logvar:
125
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
126
-
127
-
128
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
129
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
130
- if exists(given_betas):
131
- betas = given_betas
132
- else:
133
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
134
- cosine_s=cosine_s)
135
- alphas = 1. - betas
136
- alphas_cumprod = np.cumprod(alphas, axis=0)
137
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
138
-
139
- timesteps, = betas.shape
140
- self.num_timesteps = int(timesteps)
141
- self.linear_start = linear_start
142
- self.linear_end = linear_end
143
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
144
-
145
- to_torch = partial(torch.tensor, dtype=torch.float32)
146
-
147
- self.register_buffer('betas', to_torch(betas))
148
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
149
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
150
-
151
- # calculations for diffusion q(x_t | x_{t-1}) and others
152
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
153
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
154
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
155
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
156
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
157
-
158
- # calculations for posterior q(x_{t-1} | x_t, x_0)
159
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
160
- 1. - alphas_cumprod) + self.v_posterior * betas
161
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
162
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
163
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
164
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
165
- self.register_buffer('posterior_mean_coef1', to_torch(
166
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
167
- self.register_buffer('posterior_mean_coef2', to_torch(
168
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
169
-
170
- if self.parameterization == "eps":
171
- lvlb_weights = self.betas ** 2 / (
172
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
173
- elif self.parameterization == "x0":
174
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
175
- else:
176
- raise NotImplementedError("mu not supported")
177
- # TODO how to choose this term
178
- lvlb_weights[0] = lvlb_weights[1]
179
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
180
- assert not torch.isnan(self.lvlb_weights).all()
181
-
182
- @contextmanager
183
- def ema_scope(self, context=None):
184
- if self.use_ema:
185
- self.model_ema.store(self.model.parameters())
186
- self.model_ema.copy_to(self.model)
187
- if context is not None:
188
- print(f"{context}: Switched to EMA weights")
189
- try:
190
- yield None
191
- finally:
192
- if self.use_ema:
193
- self.model_ema.restore(self.model.parameters())
194
- if context is not None:
195
- print(f"{context}: Restored training weights")
196
-
197
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
198
- sd = torch.load(path, map_location="cpu")
199
- if "state_dict" in list(sd.keys()):
200
- sd = sd["state_dict"]
201
- keys = list(sd.keys())
202
-
203
- # Our model adds additional channels to the first layer to condition on an input image.
204
- # For the first layer, copy existing channel weights and initialize new channel weights to zero.
205
- input_keys = [
206
- "model.diffusion_model.input_blocks.0.0.weight",
207
- "model_ema.diffusion_modelinput_blocks00weight",
208
- ]
209
-
210
- branch_1_keys = [
211
- "model.diffusion_model.input_blocks_branch_1",
212
- "model.diffusion_model.output_blocks_branch_1",
213
- "model.diffusion_model.out_branch_1",
214
- "model_ema.diffusion_modelinput_blocks_branch_100weight",
215
- "model_ema.diffusion_modelout_branch_10weight",
216
- "model_ema.diffusion_modelout_branch_12weight",
217
-
218
- ]
219
- ignore_keys += branch_1_keys
220
- self_sd = self.state_dict()
221
-
222
-
223
- for input_key in input_keys:
224
- if input_key not in sd or input_key not in self_sd:
225
- continue
226
-
227
- input_weight = self_sd[input_key]
228
-
229
- if input_weight.size() != sd[input_key].size():
230
- print(f"Manual init: {input_key}")
231
- input_weight.zero_()
232
- input_weight[:, :4, :, :].copy_(sd[input_key])
233
- ignore_keys.append(input_key)
234
-
235
-
236
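- # Initialize the newly added *_branch_1 modules from the corresponding main-branch weights in the checkpoint; where shapes differ, copy the first 4 input channels and leave the extra channels zero-initialized.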
- for branch_1_key in branch_1_keys:
237
- start_with_branch_1_keys = [k for k in self_sd if k.startswith(branch_1_key)]
238
- main_keys = [k.replace("_branch_1", "") for k in start_with_branch_1_keys]
239
-
240
- for start_with_branch_1_key, main_key in zip(start_with_branch_1_keys, main_keys):
241
- if start_with_branch_1_key not in self_sd or main_key not in sd:
242
- continue
243
-
244
- branch_1_weight = self_sd[start_with_branch_1_key]
245
- if branch_1_weight.size() != sd[main_key].size():
246
- print(f"Manual init: {start_with_branch_1_key}")
247
- branch_1_weight.zero_()
248
- branch_1_weight[:, :4, :, :].copy_(sd[main_key])
249
- ignore_keys.append(start_with_branch_1_key)
250
- else:
251
- branch_1_weight.zero_()
252
- branch_1_weight.copy_(sd[main_key])
253
- ignore_keys.append(start_with_branch_1_key)
254
-
255
- for k in keys:
256
- for ik in ignore_keys:
257
- if k.startswith(ik):
258
- print("Deleting key {} from state_dict.".format(k))
259
- del sd[k]
260
-
261
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
262
- sd, strict=False)
263
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
264
- if len(missing) > 0:
265
- print(f"Missing Keys: {missing}")
266
- if len(unexpected) > 0:
267
- print(f"Unexpected Keys: {unexpected}")
268
-
269
-
270
- def q_mean_variance(self, x_start, t):
271
- """
272
- Get the distribution q(x_t | x_0).
273
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
274
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
275
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
276
- """
277
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
278
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
279
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
280
- return mean, variance, log_variance
281
-
282
- def predict_start_from_noise(self, x_t, t, noise):
283
- return (
284
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
285
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
286
- )
287
-
288
- def q_posterior(self, x_start, x_t, t):
289
- posterior_mean = (
290
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
291
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
292
- )
293
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
294
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
295
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
296
-
297
- def p_mean_variance(self, x, t, clip_denoised: bool):
298
- model_out = self.model(x, t)
299
- if self.parameterization == "eps":
300
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
301
- elif self.parameterization == "x0":
302
- x_recon = model_out
303
- if clip_denoised:
304
- x_recon.clamp_(-1., 1.)
305
-
306
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
307
- return model_mean, posterior_variance, posterior_log_variance
308
-
309
- @torch.no_grad()
310
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
311
- b, *_, device = *x.shape, x.device
312
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
313
- noise = noise_like(x.shape, device, repeat_noise)
314
- # no noise when t == 0
315
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
316
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
317
-
318
- @torch.no_grad()
319
- def p_sample_loop(self, shape, return_intermediates=False):
320
- device = self.betas.device
321
- b = shape[0]
322
- img = torch.randn(shape, device=device)
323
- intermediates = [img]
324
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
325
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
326
- clip_denoised=self.clip_denoised)
327
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
328
- intermediates.append(img)
329
- if return_intermediates:
330
- return img, intermediates
331
- return img
332
-
333
- @torch.no_grad()
334
- def sample(self, batch_size=16, return_intermediates=False):
335
- image_size = self.image_size
336
- channels = self.channels
337
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
338
- return_intermediates=return_intermediates)
339
-
340
- def q_sample(self, x_start, t, noise=None):
341
- noise = default(noise, lambda: torch.randn_like(x_start))
342
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
343
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
344
-
345
- def get_loss(self, pred, target, mean=True):
346
- if self.loss_type == 'l1':
347
- loss = (target - pred).abs()
348
- if mean:
349
- loss = loss.mean()
350
- elif self.loss_type == 'l2':
351
- if mean:
352
- loss = torch.nn.functional.mse_loss(target, pred)
353
- else:
354
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
355
- else:
356
- raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
357
-
358
- return loss
359
-
360
- def p_losses(self, x_start, t, noise=None):
361
- noise = default(noise, lambda: torch.randn_like(x_start))
362
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
363
- model_out = self.model(x_noisy, t)
364
-
365
- loss_dict = {}
366
- if self.parameterization == "eps":
367
- target = noise
368
- elif self.parameterization == "x0":
369
- target = x_start
370
- else:
371
- raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
372
-
373
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
374
-
375
- log_prefix = 'train' if self.training else 'val'
376
-
377
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
378
- loss_simple = loss.mean() * self.l_simple_weight
379
-
380
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
381
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
382
-
383
- loss = loss_simple + self.original_elbo_weight * loss_vlb
384
-
385
- loss_dict.update({f'{log_prefix}/loss': loss})
386
-
387
- return loss, loss_dict
388
-
389
- def forward(self, x, *args, **kwargs):
390
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
391
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
392
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
393
- return self.p_losses(x, t, *args, **kwargs)
394
-
395
- def get_input(self, batch, k):
396
- return batch[k]
397
-
398
- def shared_step(self, batch):
399
- x = self.get_input(batch, self.first_stage_key)
400
- loss, loss_dict = self(x)
401
- return loss, loss_dict
402
-
403
- def training_step(self, batch, batch_idx):
404
- loss, loss_dict = self.shared_step(batch)
405
-
406
- self.log_dict(loss_dict, prog_bar=True,
407
- logger=True, on_step=True, on_epoch=True)
408
-
409
- self.log("global_step", self.global_step,
410
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
411
-
412
- if self.use_scheduler:
413
- lr = self.optimizers().param_groups[0]['lr']
414
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
415
-
416
- return loss
417
-
418
- @torch.no_grad()
419
- def validation_step(self, batch, batch_idx):
420
- _, loss_dict_no_ema = self.shared_step(batch)
421
- with self.ema_scope():
422
- _, loss_dict_ema = self.shared_step(batch)
423
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
424
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
425
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
426
-
427
- def on_train_batch_end(self, *args, **kwargs):
428
- if self.use_ema:
429
- self.model_ema(self.model)
430
-
431
- def _get_rows_from_list(self, samples):
432
- n_imgs_per_row = len(samples)
433
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
434
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
435
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
436
- return denoise_grid
437
-
438
- @torch.no_grad()
439
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
440
- log = dict()
441
- x = self.get_input(batch, self.first_stage_key)
442
- N = min(x.shape[0], N)
443
- n_row = min(x.shape[0], n_row)
444
- x = x.to(self.device)[:N]
445
- log["inputs"] = x
446
-
447
- # get diffusion row
448
- diffusion_row = list()
449
- x_start = x[:n_row]
450
-
451
- for t in range(self.num_timesteps):
452
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
453
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
454
- t = t.to(self.device).long()
455
- noise = torch.randn_like(x_start)
456
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
457
- diffusion_row.append(x_noisy)
458
-
459
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
460
-
461
- if sample:
462
- # get denoise row
463
- with self.ema_scope("Plotting"):
464
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
465
-
466
- log["samples"] = samples
467
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
468
-
469
- if return_keys:
470
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
471
- return log
472
- else:
473
- return {key: log[key] for key in return_keys}
474
- return log
475
-
476
- def configure_optimizers(self):
477
- lr = self.learning_rate
478
- params = list(self.model.parameters())
479
- if self.learn_logvar:
480
- params = params + [self.logvar]
481
- opt = torch.optim.AdamW(params, lr=lr)
482
- return opt
483
-
484
-
485
- class LatentDiffusion(DDPM):
486
- """main class"""
487
- def __init__(self,
488
- first_stage_config,
489
- cond_stage_config,
490
- num_timesteps_cond=None,
491
- cond_stage_key="image",
492
- cond_stage_trainable=False,
493
- concat_mode=True,
494
- cond_stage_forward=None,
495
- conditioning_key=None,
496
- scale_factor=1.0,
497
- scale_by_std=False,
498
- load_ema=True,
499
- *args, **kwargs):
500
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
501
- self.scale_by_std = scale_by_std
502
- assert self.num_timesteps_cond <= kwargs['timesteps']
503
- # for backwards compatibility after implementation of DiffusionWrapper
504
- if conditioning_key is None:
505
- conditioning_key = 'concat' if concat_mode else 'crossattn'
506
- if cond_stage_config == '__is_unconditional__':
507
- conditioning_key = None
508
- ckpt_path = kwargs.pop("ckpt_path", None)
509
- ignore_keys = kwargs.pop("ignore_keys", [])
510
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
511
- self.concat_mode = concat_mode
512
- self.cond_stage_trainable = cond_stage_trainable
513
- self.cond_stage_key = cond_stage_key
514
- try:
515
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
516
- except:
517
- self.num_downs = 0
518
- if not scale_by_std:
519
- self.scale_factor = scale_factor
520
- else:
521
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
522
- self.instantiate_first_stage(first_stage_config)
523
- self.instantiate_cond_stage(cond_stage_config)
524
- self.cond_stage_forward = cond_stage_forward
525
- self.clip_denoised = False
526
- self.bbox_tokenizer = None
527
-
528
- self.restarted_from_ckpt = False
529
- if ckpt_path is not None:
530
- self.init_from_ckpt(ckpt_path, ignore_keys)
531
- self.restarted_from_ckpt = True
532
-
533
- if self.use_ema and not load_ema:
534
- self.model_ema = LitEma(self.model)
535
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
536
-
537
- def make_cond_schedule(self, ):
538
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
539
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
540
- self.cond_ids[:self.num_timesteps_cond] = ids
541
-
542
- @rank_zero_only
543
- @torch.no_grad()
544
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
545
- # only for very first batch
546
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
547
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
548
- # set rescale weight to 1./std of encodings
549
- print("### USING STD-RESCALING ###")
550
- x = super().get_input(batch, self.first_stage_key)
551
- x = x.to(self.device)
552
- encoder_posterior = self.encode_first_stage(x)
553
- z = self.get_first_stage_encoding(encoder_posterior).detach()
554
- del self.scale_factor
555
- self.register_buffer('scale_factor', 1. / z.flatten().std())
556
- print(f"setting self.scale_factor to {self.scale_factor}")
557
- print("### USING STD-RESCALING ###")
558
-
559
- def register_schedule(self,
560
- given_betas=None, beta_schedule="linear", timesteps=1000,
561
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
562
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
563
-
564
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
565
- if self.shorten_cond_schedule:
566
- self.make_cond_schedule()
567
-
568
- def instantiate_first_stage(self, config):
569
- model = instantiate_from_config(config)
570
- self.first_stage_model = model.eval()
571
- self.first_stage_model.train = disabled_train
572
- for param in self.first_stage_model.parameters():
573
- param.requires_grad = False
574
-
575
- def instantiate_cond_stage(self, config):
576
- if not self.cond_stage_trainable:
577
- if config == "__is_first_stage__":
578
- print("Using first stage also as cond stage.")
579
- self.cond_stage_model = self.first_stage_model
580
- elif config == "__is_unconditional__":
581
- print(f"Training {self.__class__.__name__} as an unconditional model.")
582
- self.cond_stage_model = None
583
- # self.be_unconditional = True
584
- else:
585
- model = instantiate_from_config(config)
586
- self.cond_stage_model = model.eval()
587
- self.cond_stage_model.train = disabled_train
588
- for param in self.cond_stage_model.parameters():
589
- param.requires_grad = False
590
- else:
591
- assert config != '__is_first_stage__'
592
- assert config != '__is_unconditional__'
593
- model = instantiate_from_config(config)
594
- self.cond_stage_model = model
595
-
596
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
597
- denoise_row = []
598
- for zd in tqdm(samples, desc=desc):
599
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
600
- force_not_quantize=force_no_decoder_quantization))
601
- n_imgs_per_row = len(denoise_row)
602
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
603
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
604
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
605
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
606
- return denoise_grid
607
-
608
- def get_first_stage_encoding(self, encoder_posterior):
609
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
610
- z = encoder_posterior.sample()
611
- elif isinstance(encoder_posterior, torch.Tensor):
612
- z = encoder_posterior
613
- else:
614
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
615
- return self.scale_factor * z
616
-
617
- def get_learned_conditioning(self, c):
618
- if self.cond_stage_forward is None:
619
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
620
- c = self.cond_stage_model.encode(c)
621
- if isinstance(c, DiagonalGaussianDistribution):
622
- c = c.mode()
623
- else:
624
- c = self.cond_stage_model(c)
625
- else:
626
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
627
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
628
- return c
629
-
630
- def meshgrid(self, h, w):
631
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
632
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
633
-
634
- arr = torch.cat([y, x], dim=-1)
635
- return arr
636
-
637
- def delta_border(self, h, w):
638
- """
639
- :param h: height
640
- :param w: width
641
- :return: normalized distance to image border,
642
- with min distance = 0 at the border and max dist = 0.5 at the image center
643
- """
644
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
645
- arr = self.meshgrid(h, w) / lower_right_corner
646
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
647
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
648
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
649
- return edge_dist
650
-
651
- def get_weighting(self, h, w, Ly, Lx, device):
652
- weighting = self.delta_border(h, w)
653
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
654
- self.split_input_params["clip_max_weight"], )
655
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
656
-
657
- if self.split_input_params["tie_braker"]:
658
- L_weighting = self.delta_border(Ly, Lx)
659
- L_weighting = torch.clip(L_weighting,
660
- self.split_input_params["clip_min_tie_weight"],
661
- self.split_input_params["clip_max_tie_weight"])
662
-
663
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
664
- weighting = weighting * L_weighting
665
- return weighting
666
-
667
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
668
- """
669
- :param x: img of size (bs, c, h, w)
670
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
671
- """
672
- bs, nc, h, w = x.shape
673
-
674
- # number of crops in image
675
- Ly = (h - kernel_size[0]) // stride[0] + 1
676
- Lx = (w - kernel_size[1]) // stride[1] + 1
677
-
678
- if uf == 1 and df == 1:
679
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
680
- unfold = torch.nn.Unfold(**fold_params)
681
-
682
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
683
-
684
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
685
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
686
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
687
-
688
- elif uf > 1 and df == 1:
689
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
690
- unfold = torch.nn.Unfold(**fold_params)
691
-
692
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
693
- dilation=1, padding=0,
694
- stride=(stride[0] * uf, stride[1] * uf))
695
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
696
-
697
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
698
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
699
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
700
-
701
- elif df > 1 and uf == 1:
702
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
703
- unfold = torch.nn.Unfold(**fold_params)
704
-
705
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
706
- dilation=1, padding=0,
707
- stride=(stride[0] // df, stride[1] // df))
708
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
709
-
710
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
711
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
712
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
713
-
714
- else:
715
- raise NotImplementedError
716
-
717
- return fold, unfold, normalization, weighting
718
-
719
- @torch.no_grad()
720
- def get_input(self, batch, keys, return_first_stage_outputs=False, force_c_encode=False,
721
- cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
722
- x_0 = super().get_input(batch, keys[0])
723
- x_1 = super().get_input(batch, keys[1])
724
- if bs is not None:
725
- x_0 = x_0[:bs]
726
- x_1 = x_1[:bs]
727
- x_0 = x_0.to(self.device)
728
- x_1 = x_1.to(self.device)
729
- encoder_posterior = self.encode_first_stage(x_0)
730
- z_0 = self.get_first_stage_encoding(encoder_posterior).detach()
731
- z_1 = self.get_first_stage_encoding(self.encode_first_stage(x_1)).detach()
732
- cond_key = cond_key or self.cond_stage_key
733
- xc = super().get_input(batch, cond_key)
734
- if bs is not None:
735
- xc["c_crossattn"] = xc["c_crossattn"][:bs]
736
- xc["c_concat"] = xc["c_concat"][:bs]
737
- cond = {}
738
-
739
- # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
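- # prompt_mask swaps in the null prompt for random < 2 * uncond, while input_mask zeroes the image latent for uncond <= random < 3 * uncond, so text-only, image-only, and joint dropout each cover an `uncond` fraction of the batch.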
740
- random = torch.rand(x_0.size(0), device=x_0.device)
741
- prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
742
- input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
743
-
744
- null_prompt = self.get_learned_conditioning([""])
745
- cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
746
- cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
747
-
748
- out = [z_0, z_1, cond]
749
- if return_first_stage_outputs:
750
- x_0_rec = self.decode_first_stage(z_0)
751
- x_1_rec = self.decode_first_stage(z_1)
752
- out.extend([x_0, x_0_rec, x_1, x_1_rec])
753
- if return_original_cond:
754
- out.append(xc)
755
-
756
- return out
757
-
758
- @torch.no_grad()
759
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
760
- if predict_cids:
761
- if z.dim() == 4:
762
- z = torch.argmax(z.exp(), dim=1).long()
763
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
764
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
765
-
766
- z = 1. / self.scale_factor * z
767
-
768
- if hasattr(self, "split_input_params"):
769
- if self.split_input_params["patch_distributed_vq"]:
770
- ks = self.split_input_params["ks"] # eg. (128, 128)
771
- stride = self.split_input_params["stride"] # eg. (64, 64)
772
- uf = self.split_input_params["vqf"]
773
- bs, nc, h, w = z.shape
774
- if ks[0] > h or ks[1] > w:
775
- ks = (min(ks[0], h), min(ks[1], w))
776
- print("reducing Kernel")
777
-
778
- if stride[0] > h or stride[1] > w:
779
- stride = (min(stride[0], h), min(stride[1], w))
780
- print("reducing stride")
781
-
782
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
783
-
784
- z = unfold(z) # (bn, nc * prod(**ks), L)
785
- # 1. Reshape to img shape
786
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
787
-
788
- # 2. apply model loop over last dim
789
- if isinstance(self.first_stage_model, VQModelInterface):
790
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
791
- force_not_quantize=predict_cids or force_not_quantize)
792
- for i in range(z.shape[-1])]
793
- else:
794
-
795
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
796
- for i in range(z.shape[-1])]
797
-
798
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
799
- o = o * weighting
800
- # Reverse 1. reshape to img shape
801
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
802
- # stitch crops together
803
- decoded = fold(o)
804
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
805
- return decoded
806
- else:
807
- if isinstance(self.first_stage_model, VQModelInterface):
808
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
809
- else:
810
- return self.first_stage_model.decode(z)
811
-
812
- else:
813
- if isinstance(self.first_stage_model, VQModelInterface):
814
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
815
- else:
816
- return self.first_stage_model.decode(z)
817
-
818
- # same as above but without decorator
819
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
820
- if predict_cids:
821
- if z.dim() == 4:
822
- z = torch.argmax(z.exp(), dim=1).long()
823
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
824
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
825
-
826
- z = 1. / self.scale_factor * z
827
-
828
- if hasattr(self, "split_input_params"):
829
- if self.split_input_params["patch_distributed_vq"]:
830
- ks = self.split_input_params["ks"] # eg. (128, 128)
831
- stride = self.split_input_params["stride"] # eg. (64, 64)
832
- uf = self.split_input_params["vqf"]
833
- bs, nc, h, w = z.shape
834
- if ks[0] > h or ks[1] > w:
835
- ks = (min(ks[0], h), min(ks[1], w))
836
- print("reducing Kernel")
837
-
838
- if stride[0] > h or stride[1] > w:
839
- stride = (min(stride[0], h), min(stride[1], w))
840
- print("reducing stride")
841
-
842
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
843
-
844
- z = unfold(z) # (bn, nc * prod(**ks), L)
845
- # 1. Reshape to img shape
846
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
847
-
848
- # 2. apply model loop over last dim
849
- if isinstance(self.first_stage_model, VQModelInterface):
850
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
851
- force_not_quantize=predict_cids or force_not_quantize)
852
- for i in range(z.shape[-1])]
853
- else:
854
-
855
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
856
- for i in range(z.shape[-1])]
857
-
858
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
859
- o = o * weighting
860
- # Reverse 1. reshape to img shape
861
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
862
- # stitch crops together
863
- decoded = fold(o)
864
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
865
- return decoded
866
- else:
867
- if isinstance(self.first_stage_model, VQModelInterface):
868
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
869
- else:
870
- return self.first_stage_model.decode(z)
871
-
872
- else:
873
- if isinstance(self.first_stage_model, VQModelInterface):
874
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
875
- else:
876
- return self.first_stage_model.decode(z)
877
-
878
- @torch.no_grad()
879
- def encode_first_stage(self, x):
880
- if hasattr(self, "split_input_params"):
881
- if self.split_input_params["patch_distributed_vq"]:
882
- ks = self.split_input_params["ks"] # eg. (128, 128)
883
- stride = self.split_input_params["stride"] # eg. (64, 64)
884
- df = self.split_input_params["vqf"]
885
- self.split_input_params['original_image_size'] = x.shape[-2:]
886
- bs, nc, h, w = x.shape
887
- if ks[0] > h or ks[1] > w:
888
- ks = (min(ks[0], h), min(ks[1], w))
889
- print("reducing Kernel")
890
-
891
- if stride[0] > h or stride[1] > w:
892
- stride = (min(stride[0], h), min(stride[1], w))
893
- print("reducing stride")
894
-
895
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
896
- z = unfold(x) # (bn, nc * prod(**ks), L)
897
- # Reshape to img shape
898
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
899
-
900
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
901
- for i in range(z.shape[-1])]
902
-
903
- o = torch.stack(output_list, axis=-1)
904
- o = o * weighting
905
-
906
- # Reverse reshape to img shape
907
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
908
- # stitch crops together
909
- decoded = fold(o)
910
- decoded = decoded / normalization
911
- return decoded
912
-
913
- else:
914
- return self.first_stage_model.encode(x)
915
- else:
916
- return self.first_stage_model.encode(x)
917
-
918
- def shared_step(self, batch, **kwargs):
919
- x_0, x_1, c = self.get_input(batch, self.first_stage_key)
920
- loss = self(x_0, x_1, c)
921
- return loss
922
-
923
- def forward(self, x_0, x_1, c, *args, **kwargs):
924
- t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=self.device).long()
925
- if self.model.conditioning_key is not None:
926
- assert c is not None
927
- # in pix2pix, cond_stage_trainable and shorten_cond_schedule are false
928
- if self.cond_stage_trainable:
929
- c = self.get_learned_conditioning(c)
930
- if self.shorten_cond_schedule: # TODO: drop this option
931
- tc = self.cond_ids[t].to(self.device)
932
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
933
- return self.p_losses(x_0, x_1, c, t, *args, **kwargs)
934
-
935
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
936
- def rescale_bbox(bbox):
937
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
938
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
939
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
940
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
941
- return x0, y0, w, h
942
-
943
- return [rescale_bbox(b) for b in bboxes]
944
-
945
- def apply_model(self, x_noisy_0, x_noisy_1, t, cond, return_ids=False):
946
- if isinstance(cond, dict):
947
- # hybrid case, cond is expected to be a dict
948
- pass
949
- else:
950
- if not isinstance(cond, list):
951
- cond = [cond]
952
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
953
- cond = {key: cond}
954
-
955
- if hasattr(self, "split_input_params"):
956
- assert len(cond) == 1 # todo can only deal with one conditioning atm
957
- assert not return_ids
958
- ks = self.split_input_params["ks"] # eg. (128, 128)
959
- stride = self.split_input_params["stride"] # eg. (64, 64)
960
-
961
- h, w = x_noisy.shape[-2:]
962
-
963
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
964
-
965
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
966
- # Reshape to img shape
967
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
968
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
969
-
970
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
971
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
972
- c_key = next(iter(cond.keys())) # get key
973
- c = next(iter(cond.values())) # get value
974
- assert (len(c) == 1) # todo extend to list with more than one elem
975
- c = c[0] # get element
976
-
977
- c = unfold(c)
978
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
979
-
980
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
981
-
982
- elif self.cond_stage_key == 'coordinates_bbox':
983
- assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
984
-
985
- # assuming padding of unfold is always 0 and its dilation is always 1
986
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
987
- full_img_h, full_img_w = self.split_input_params['original_image_size']
988
- # as we are operating on latents, we need the factor from the original image size to the
989
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
990
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
991
- rescale_latent = 2 ** (num_downs)
992
-
993
- # get the top-left positions of the patches in the form expected by the bbox tokenizer, therefore we
994
- # need to rescale the tl patch coordinates to be in between (0,1)
995
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
996
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
997
- for patch_nr in range(z.shape[-1])]
998
-
999
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
1000
- patch_limits = [(x_tl, y_tl,
1001
- rescale_latent * ks[0] / full_img_w,
1002
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
1003
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
1004
-
1005
- # tokenize crop coordinates for the bounding boxes of the respective patches
1006
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
1007
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
1008
- print(patch_limits_tknzd[0].shape)
1009
- # cut tknzd crop position from conditioning
1010
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
1011
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
1012
- print(cut_cond.shape)
1013
-
1014
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
1015
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
1016
- print(adapted_cond.shape)
1017
- adapted_cond = self.get_learned_conditioning(adapted_cond)
1018
- print(adapted_cond.shape)
1019
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
1020
- print(adapted_cond.shape)
1021
-
1022
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
1023
-
1024
- else:
1025
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
1026
-
1027
- # apply model by loop over crops
1028
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1029
- assert not isinstance(output_list[0],
1030
- tuple) # todo cant deal with multiple model outputs check this never happens
1031
-
1032
- o = torch.stack(output_list, axis=-1)
1033
- o = o * weighting
1034
- # Reverse reshape to img shape
1035
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1036
- # stitch crops together
1037
- x_recon = fold(o) / normalization
1038
-
1039
- else:
1040
- x_recon_0, x_recon_1 = self.model(x_noisy_0, x_noisy_1, t, **cond)
1041
-
1042
- if isinstance(x_recon_0, tuple) and not return_ids:
1043
- return x_recon_0[0], x_recon_1[0]
1044
- else:
1045
- return x_recon_0, x_recon_1
1046
-
1047
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1048
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1049
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1050
-
1051
- def _prior_bpd(self, x_start):
1052
- """
1053
- Get the prior KL term for the variational lower-bound, measured in
1054
- bits-per-dim.
1055
- This term can't be optimized, as it only depends on the encoder.
1056
- :param x_start: the [N x C x ...] tensor of inputs.
1057
- :return: a batch of [N] KL values (in bits), one per batch element.
1058
- """
1059
- batch_size = x_start.shape[0]
1060
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1061
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1062
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1063
- return mean_flat(kl_prior) / np.log(2.0)
1064
-
1065
- def p_losses(self, x_start_0, x_start_1, cond, t, noise=None):
1066
- noise_0 = default(noise, lambda: torch.randn_like(x_start_0))
1067
- noise_1 = default(noise, lambda: torch.randn_like(x_start_1))
1068
- x_noisy_0 = self.q_sample(x_start=x_start_0, t=t, noise=noise_0)
1069
- x_noisy_1 = self.q_sample(x_start=x_start_1, t=t, noise=noise_1)
1070
- model_output_0, model_output_1 = self.apply_model(x_noisy_0, x_noisy_1, t, cond)
1071
-
1072
- loss_dict = {}
1073
- prefix = 'train' if self.training else 'val'
1074
-
1075
- if self.parameterization == "x0":
1076
- target_0 = x_start_0
1077
- target_1 = x_start_1
1078
- elif self.parameterization == "eps":
1079
- target_0 = noise_0
1080
- target_1 = noise_1
1081
- else:
1082
- raise NotImplementedError()
1083
-
1084
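- # Both diffusion branches share the same prediction target type; their per-sample losses are averaged before the logvar weighting and VLB term below.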
- loss_simple_0 = self.get_loss(model_output_0, target_0, mean=False).mean([1, 2, 3])
1085
- loss_simple_1 = self.get_loss(model_output_1, target_1, mean=False).mean([1, 2, 3])
1086
- loss_simple = (loss_simple_0 + loss_simple_1) / 2
1087
-
1088
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1089
-
1090
- # logvar_t = self.logvar[t].to(self.device)
1091
- # Make sure self.logvar lives on the same device as the model (self.device)
1092
- self.logvar = self.logvar.to(self.device)
1093
- logvar_t = self.logvar[t]
1094
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
1095
-
1096
- if self.learn_logvar:
1097
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1098
- loss_dict.update({'logvar': self.logvar.data.mean()})
1099
-
1100
- loss = self.l_simple_weight * loss.mean()
1101
-
1102
- loss_vlb_0 = self.get_loss(model_output_0, target_0, mean=False).mean(dim=(1, 2, 3))
1103
- loss_vlb_1 = self.get_loss(model_output_1, target_1, mean=False).mean(dim=(1, 2, 3))
1104
- loss_vlb = (loss_vlb_0 + loss_vlb_1) / 2
1105
-
1106
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1107
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1108
- loss += (self.original_elbo_weight * loss_vlb)
1109
- loss_dict.update({f'{prefix}/loss': loss})
1110
-
1111
- return loss, loss_dict
1112
-
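
The deleted p_losses above trains both branches jointly: with eps-parameterization the targets are the two independent noise samples, and the per-branch MSEs (reduced over channel and spatial dims) are averaged before the logvar weighting. A minimal sketch of that averaging, assuming hypothetical tensors of shape (B, C, H, W):

# Minimal sketch of the two-branch simple-loss averaging (eps-parameterization).
# model_output_0/1 and noise_0/1 are hypothetical (B, C, H, W) tensors.
import torch

def two_branch_simple_loss(model_output_0, noise_0, model_output_1, noise_1):
    # per-sample MSE over channel and spatial dims for each branch
    loss_0 = ((model_output_0 - noise_0) ** 2).mean(dim=[1, 2, 3])
    loss_1 = ((model_output_1 - noise_1) ** 2).mean(dim=[1, 2, 3])
    # both branches contribute equally to the simple loss
    return (loss_0 + loss_1) / 2  # shape (B,)
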
1113
- def p_mean_variance(self, x_0, x_1, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1114
- return_x0=False, score_corrector=None, corrector_kwargs=None):
1115
- t_in = t
1116
- model_out_0, model_out_1 = self.apply_model(x_0, x_1, t_in, c, return_ids=return_codebook_ids)
1117
-
1118
- if score_corrector is not None:
1119
- assert self.parameterization == "eps"
1120
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1121
-
1122
- if return_codebook_ids:
1123
- model_out, logits = model_out
1124
-
1125
- if self.parameterization == "eps":
1126
- x_recon_0 = self.predict_start_from_noise(x_0, t=t, noise=model_out_0)
1127
- x_recon_1 = self.predict_start_from_noise(x_1, t=t, noise=model_out_1)
1128
- elif self.parameterization == "x0":
1129
- x_recon = model_out
1130
- else:
1131
- raise NotImplementedError()
1132
- if clip_denoised:
1133
- x_recon.clamp_(-1., 1.)
1134
- if quantize_denoised:
1135
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1136
-
1137
- model_mean_0, posterior_variance_0, posterior_log_variance_0 = self.q_posterior(x_start=x_recon_0, x_t=x_0, t=t)
1138
- model_mean_1, posterior_variance_1, posterior_log_variance_1 = self.q_posterior(x_start=x_recon_1, x_t=x_1, t=t)
1139
- if return_codebook_ids:
1140
- return model_mean, posterior_variance, posterior_log_variance, logits
1141
- elif return_x0:
1142
- return model_mean, posterior_variance, posterior_log_variance, x_recon
1143
- else:
1144
- return model_mean_0, posterior_variance_0, posterior_log_variance_0, model_mean_1, posterior_variance_1, posterior_log_variance_1
1145
-
1146
- @torch.no_grad()
1147
- def p_sample(self, x_0, x_1, c, t, clip_denoised=False, repeat_noise=False,
1148
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1149
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1150
- b, *_, device = *x_0.shape, x_0.device
1151
- outputs = self.p_mean_variance(x_0=x_0, x_1=x_1, c=c, t=t, clip_denoised=clip_denoised,
1152
- return_codebook_ids=return_codebook_ids,
1153
- quantize_denoised=quantize_denoised,
1154
- return_x0=return_x0,
1155
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1156
-
1157
- if return_codebook_ids:
1158
- raise DeprecationWarning("Support dropped.")
1159
- model_mean, _, model_log_variance, logits = outputs
1160
- elif return_x0:
1161
- model_mean, _, model_log_variance, x0 = outputs
1162
- else:
1163
- model_mean_0, _, model_log_variance_0, model_mean_1, _, model_log_variance_1 = outputs
1164
-
1165
- noise = noise_like(x_0.shape, device, repeat_noise) * temperature
1166
- if noise_dropout > 0.:
1167
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1168
- # no noise when t == 0
1169
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_0.shape) - 1)))
1170
-
1171
- if return_codebook_ids:
1172
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1173
- if return_x0:
1174
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1175
- else:
1176
- return model_mean_0 + nonzero_mask * (0.5 * model_log_variance_0).exp() * noise, \
1177
- model_mean_1 + nonzero_mask * (0.5 * model_log_variance_1).exp() * noise
1178
-
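
The deleted p_sample above draws x_{t-1} from the predicted posterior mean plus exp(0.5 * log-variance)-scaled Gaussian noise; the nonzero_mask zeroes the noise for samples whose timestep is 0, so the final denoising step is deterministic. A small sketch of that masking, with assumed shapes:

# Sketch of the "no noise when t == 0" masking used in p_sample above.
import torch

def add_posterior_noise(model_mean, model_log_variance, t):
    # model_mean, model_log_variance: (B, C, H, W); t: (B,) integer timesteps
    noise = torch.randn_like(model_mean)
    # 1 for samples with t > 0, 0 for samples already at t == 0
    nonzero_mask = (t != 0).float().view(-1, 1, 1, 1)
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
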
1179
- @torch.no_grad()
1180
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1181
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1182
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1183
- log_every_t=None):
1184
- if not log_every_t:
1185
- log_every_t = self.log_every_t
1186
- timesteps = self.num_timesteps
1187
- if batch_size is not None:
1188
- b = batch_size if batch_size is not None else shape[0]
1189
- shape = [batch_size] + list(shape)
1190
- else:
1191
- b = batch_size = shape[0]
1192
- if x_T is None:
1193
- img = torch.randn(shape, device=self.device)
1194
- else:
1195
- img = x_T
1196
- intermediates = []
1197
- if cond is not None:
1198
- if isinstance(cond, dict):
1199
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1200
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1201
- else:
1202
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1203
-
1204
- if start_T is not None:
1205
- timesteps = min(timesteps, start_T)
1206
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1207
- total=timesteps) if verbose else reversed(
1208
- range(0, timesteps))
1209
- if type(temperature) == float:
1210
- temperature = [temperature] * timesteps
1211
-
1212
- for i in iterator:
1213
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1214
- if self.shorten_cond_schedule:
1215
- assert self.model.conditioning_key != 'hybrid'
1216
- tc = self.cond_ids[ts].to(cond.device)
1217
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1218
-
1219
- img, x0_partial = self.p_sample(img, cond, ts,
1220
- clip_denoised=self.clip_denoised,
1221
- quantize_denoised=quantize_denoised, return_x0=True,
1222
- temperature=temperature[i], noise_dropout=noise_dropout,
1223
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1224
- if mask is not None:
1225
- assert x0 is not None
1226
- img_orig = self.q_sample(x0, ts)
1227
- img = img_orig * mask + (1. - mask) * img
1228
-
1229
- if i % log_every_t == 0 or i == timesteps - 1:
1230
- intermediates.append(x0_partial)
1231
- if callback: callback(i)
1232
- if img_callback: img_callback(img, i)
1233
- return img, intermediates
1234
-
1235
- @torch.no_grad()
1236
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1237
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1238
- mask=None, x0=None, img_callback=None, start_T=None,
1239
- log_every_t=None):
1240
-
1241
- if not log_every_t:
1242
- log_every_t = self.log_every_t
1243
- device = self.betas.device
1244
- b = shape[0]
1245
-
1246
- if x_T is None:
1247
- img_0 = torch.randn(shape, device=device)
1248
- img_1 = torch.randn(shape, device=device)
1249
- else:
1250
- img_0 = img_1 = x_T
1251
-
1252
- intermediates = [img_0]
1253
- if timesteps is None:
1254
- timesteps = self.num_timesteps
1255
-
1256
- if start_T is not None:
1257
- timesteps = min(timesteps, start_T)
1258
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1259
- range(0, timesteps))
1260
-
1261
- if mask is not None:
1262
- assert x0 is not None
1263
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1264
-
1265
- for i in iterator:
1266
- ts = torch.full((b,), i, device=device, dtype=torch.long)
1267
- if self.shorten_cond_schedule:
1268
- assert self.model.conditioning_key != 'hybrid'
1269
- tc = self.cond_ids[ts].to(cond.device)
1270
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1271
-
1272
- img_0, img_1 = self.p_sample(img_0, img_1, cond, ts,
1273
- clip_denoised=self.clip_denoised,
1274
- quantize_denoised=quantize_denoised)
1275
-
1276
- if mask is not None:
1277
- img_orig = self.q_sample(x0, ts)
1278
- img = img_orig * mask + (1. - mask) * img
1279
-
1280
- if i % log_every_t == 0 or i == timesteps - 1:
1281
- intermediates.append(img_0)
1282
- if callback: callback(i)
1283
- if img_callback: img_callback(img_0, i)
1284
-
1285
- if return_intermediates:
1286
- return img_0, intermediates
1287
- return img_0
1288
-
1289
- @torch.no_grad()
1290
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1291
- verbose=True, timesteps=None, quantize_denoised=False,
1292
- mask=None, x0=None, shape=None,**kwargs):
1293
- if shape is None:
1294
- shape = (batch_size, self.channels, self.image_size, self.image_size)
1295
- if cond is not None:
1296
- if isinstance(cond, dict):
1297
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1298
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1299
- else:
1300
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1301
- return self.p_sample_loop(cond,
1302
- shape,
1303
- return_intermediates=return_intermediates, x_T=x_T,
1304
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1305
- mask=mask, x0=x0)
1306
-
1307
- @torch.no_grad()
1308
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1309
-
1310
- if ddim:
1311
- ddim_sampler = DDIMSampler(self)
1312
- shape = (self.channels, self.image_size, self.image_size)
1313
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1314
- shape,cond,verbose=False,**kwargs)
1315
-
1316
- else:
1317
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1318
- return_intermediates=True,**kwargs)
1319
-
1320
- return samples, intermediates
1321
-
1322
-
1323
- @torch.no_grad()
1324
- def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1325
- quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1326
- plot_diffusion_rows=False, **kwargs):
1327
-
1328
- use_ddim = False
1329
-
1330
- log = dict()
1331
- # z_0, z_1, c, x_0, x_0_rec, x_1, x_1_rec, xc
1332
- # z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1333
- z_0, z_1, c, x_0, x_0_rec, x_1, x_1_rec, xc = self.get_input(batch, self.first_stage_key,
1334
- return_first_stage_outputs=True,
1335
- force_c_encode=True,
1336
- return_original_cond=True,
1337
- bs=N, uncond=0)
1338
- N = min(x_0.shape[0], N)
1339
- n_row = min(x_0.shape[0], n_row)
1340
- log["inputs"] = x_0
1341
- log["reals"] = xc["c_concat"]
1342
- log["reconstruction"] = x_0_rec
1343
- if self.model.conditioning_key is not None:
1344
- if hasattr(self.cond_stage_model, "decode"):
1345
- xc = self.cond_stage_model.decode(c)
1346
- log["conditioning"] = xc
1347
- elif self.cond_stage_key in ["caption"]:
1348
- xc = log_txt_as_img((x_0.shape[2], x_0.shape[3]), batch["caption"])
1349
- log["conditioning"] = xc
1350
- elif self.cond_stage_key == 'class_label':
1351
- xc = log_txt_as_img((x_0.shape[2], x_0.shape[3]), batch["human_label"])
1352
- log['conditioning'] = xc
1353
- elif isimage(xc):
1354
- log["conditioning"] = xc
1355
- if ismap(xc):
1356
- log["original_conditioning"] = self.to_rgb(xc)
1357
-
1358
- if plot_diffusion_rows:
1359
- # get diffusion row
1360
- diffusion_row = list()
1361
- z_start = z[:n_row]
1362
- for t in range(self.num_timesteps):
1363
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1364
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1365
- t = t.to(self.device).long()
1366
- noise = torch.randn_like(z_start)
1367
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1368
- diffusion_row.append(self.decode_first_stage(z_noisy))
1369
-
1370
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1371
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1372
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1373
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1374
- log["diffusion_row"] = diffusion_grid
1375
-
1376
- if sample:
1377
- # get denoise row
1378
- with self.ema_scope("Plotting"):
1379
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1380
- ddim_steps=ddim_steps,eta=ddim_eta)
1381
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1382
- x_samples = self.decode_first_stage(samples)
1383
- log["samples"] = x_samples
1384
- if plot_denoise_rows:
1385
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1386
- log["denoise_row"] = denoise_grid
1387
-
1388
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1389
- self.first_stage_model, IdentityFirstStage):
1390
- # also display when quantizing x0 while sampling
1391
- with self.ema_scope("Plotting Quantized Denoised"):
1392
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1393
- ddim_steps=ddim_steps,eta=ddim_eta,
1394
- quantize_denoised=True)
1395
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1396
- # quantize_denoised=True)
1397
- x_samples = self.decode_first_stage(samples.to(self.device))
1398
- log["samples_x0_quantized"] = x_samples
1399
-
1400
- if inpaint:
1401
- # make a simple center square
1402
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
1403
- mask = torch.ones(N, h, w).to(self.device)
1404
- # zeros will be filled in
1405
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1406
- mask = mask[:, None, ...]
1407
- with self.ema_scope("Plotting Inpaint"):
1408
-
1409
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1410
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1411
- x_samples = self.decode_first_stage(samples.to(self.device))
1412
- log["samples_inpainting"] = x_samples
1413
- log["mask"] = mask
1414
-
1415
- # outpaint
1416
- with self.ema_scope("Plotting Outpaint"):
1417
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1418
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1419
- x_samples = self.decode_first_stage(samples.to(self.device))
1420
- log["samples_outpainting"] = x_samples
1421
-
1422
- if plot_progressive_rows:
1423
- with self.ema_scope("Plotting Progressives"):
1424
- img, progressives = self.progressive_denoising(c,
1425
- shape=(self.channels, self.image_size, self.image_size),
1426
- batch_size=N)
1427
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1428
- log["progressive_row"] = prog_row
1429
-
1430
- if return_keys:
1431
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1432
- return log
1433
- else:
1434
- return {key: log[key] for key in return_keys}
1435
- return log
1436
-
1437
- def configure_optimizers(self):
1438
- lr = self.learning_rate
1439
- params = list(self.model.parameters())
1440
- if self.cond_stage_trainable:
1441
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1442
- params = params + list(self.cond_stage_model.parameters())
1443
- if self.learn_logvar:
1444
- print('Diffusion model optimizing logvar')
1445
- params.append(self.logvar)
1446
- opt = torch.optim.AdamW(params, lr=lr)
1447
- if self.use_scheduler:
1448
- assert 'target' in self.scheduler_config
1449
- scheduler = instantiate_from_config(self.scheduler_config)
1450
-
1451
- print("Setting up LambdaLR scheduler...")
1452
- scheduler = [
1453
- {
1454
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1455
- 'interval': 'step',
1456
- 'frequency': 1
1457
- }]
1458
- return [opt], scheduler
1459
- return opt
1460
-
1461
- @torch.no_grad()
1462
- def to_rgb(self, x):
1463
- x = x.float()
1464
- if not hasattr(self, "colorize"):
1465
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1466
- x = nn.functional.conv2d(x, weight=self.colorize)
1467
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1468
- return x
1469
-
1470
-
1471
- class DiffusionWrapper(pl.LightningModule):
1472
- def __init__(self, diff_model_config, conditioning_key):
1473
- super().__init__()
1474
- self.diffusion_model = instantiate_from_config(diff_model_config)
1475
- self.conditioning_key = conditioning_key
1476
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'hybrid_three_for_mask', 'adm']
1477
-
1478
- def forward(self, x_0, x_1, t, c_concat: list = None, c_crossattn: list = None):
1479
- if self.conditioning_key is None:
1480
- out = self.diffusion_model(x, t)
1481
- elif self.conditioning_key == 'concat':
1482
- xc = torch.cat([x] + c_concat, dim=1)
1483
- out = self.diffusion_model(xc, t)
1484
- elif self.conditioning_key == 'crossattn':
1485
- cc = torch.cat(c_crossattn, 1)
1486
- out = self.diffusion_model(x, t, context=cc)
1487
- elif self.conditioning_key == 'hybrid':
1488
- xc_0 = torch.cat([x_0] + c_concat, dim=1)
1489
- xc_1 = torch.cat([x_1] + c_concat, dim=1)
1490
- cc = torch.cat(c_crossattn, 1)
1491
- out_1, out_2 = self.diffusion_model(xc_0, xc_1, t, context=cc)
1492
- elif self.conditioning_key == 'hybrid_three_for_mask':
1493
- xc_0 = torch.cat([x_0] + c_concat, dim=1)
1494
- xc_1 = torch.cat([x_0, x_1] + c_concat, dim=1)
1495
- cc = torch.cat(c_crossattn, 1)
1496
- out_1, out_2 = self.diffusion_model(xc_0, xc_1, t, context=cc)
1497
- elif self.conditioning_key == 'adm':
1498
- cc = c_crossattn[0]
1499
- out = self.diffusion_model(x, t, y=cc)
1500
- else:
1501
- raise NotImplementedError()
1502
-
1503
- return out_1, out_2
1504
-
1505
-
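
In the deleted DiffusionWrapper above, the 'hybrid' conditioning key concatenates the encoded conditioning image onto each noisy latent along the channel dimension before the UNet call, while 'hybrid_three_for_mask' additionally prepends the image latent to the mask branch's input. A rough sketch of the resulting channel layout, using hypothetical tensors rather than the actual UNet:

# Rough sketch of the channel concatenation for the two hybrid modes above.
import torch

x_0 = torch.randn(1, 4, 32, 32)          # noisy image latent
x_1 = torch.randn(1, 4, 32, 32)          # noisy mask latent
c_concat = [torch.randn(1, 4, 32, 32)]   # encoded conditioning image

# 'hybrid': both branches see (latent + conditioning) channels
xc_0 = torch.cat([x_0] + c_concat, dim=1)              # (1, 8, 32, 32)
xc_1 = torch.cat([x_1] + c_concat, dim=1)              # (1, 8, 32, 32)

# 'hybrid_three_for_mask': the mask branch also sees the image latent
xc_1_three = torch.cat([x_0, x_1] + c_concat, dim=1)   # (1, 12, 32, 32)
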
1506
- class Layout2ImgDiffusion(LatentDiffusion):
1507
- # TODO: move all layout-specific hacks to this class
1508
- def __init__(self, cond_stage_key, *args, **kwargs):
1509
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1510
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1511
-
1512
- def log_images(self, batch, N=8, *args, **kwargs):
1513
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1514
-
1515
- key = 'train' if self.training else 'validation'
1516
- dset = self.trainer.datamodule.datasets[key]
1517
- mapper = dset.conditional_builders[self.cond_stage_key]
1518
-
1519
- bbox_imgs = []
1520
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1521
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
1522
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1523
- bbox_imgs.append(bboximg)
1524
-
1525
- cond_img = torch.stack(bbox_imgs, dim=0)
1526
- logs['bbox_image'] = cond_img
1527
- return logs

stable_diffusion/ldm/models/diffusion/ddpm_pam_separate_mask_block.py DELETED
@@ -1,1608 +0,0 @@
1
- """
2
- wild mixture of
3
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
- https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
- https://github.com/CompVis/taming-transformers
6
- -- merci
7
- """
8
-
9
- # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
10
- # See more details in LICENSE.
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- import numpy as np
16
- import pytorch_lightning as pl
17
- from torch.optim.lr_scheduler import LambdaLR
18
- from einops import rearrange, repeat
19
- from contextlib import contextmanager
20
- from functools import partial
21
- from tqdm import tqdm
22
- from torchvision.utils import make_grid
23
- from pytorch_lightning.utilities.distributed import rank_zero_only
24
-
25
- from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
26
- from ldm.modules.ema import LitEma
27
- from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
28
- from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
29
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
30
- from ldm.models.diffusion.ddim import DDIMSampler
31
-
32
-
33
- __conditioning_keys__ = {'concat': 'c_concat',
34
- 'crossattn': 'c_crossattn',
35
- 'adm': 'y'}
36
-
37
-
38
- def disabled_train(self, mode=True):
39
- """Overwrite model.train with this function to make sure train/eval mode
40
- does not change anymore."""
41
- return self
42
-
43
-
44
- def uniform_on_device(r1, r2, shape, device):
45
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
46
-
47
-
48
- class DDPM(pl.LightningModule):
49
- # classic DDPM with Gaussian diffusion, in image space
50
- def __init__(self,
51
- unet_config,
52
- timesteps=1000,
53
- beta_schedule="linear",
54
- loss_type="l2",
55
- ckpt_path=None,
56
- ignore_keys=[],
57
- load_only_unet=False,
58
- monitor="val/loss",
59
- use_ema=True,
60
- first_stage_key="image",
61
- image_size=256,
62
- channels=3,
63
- log_every_t=100,
64
- clip_denoised=True,
65
- linear_start=1e-4,
66
- linear_end=2e-2,
67
- cosine_s=8e-3,
68
- given_betas=None,
69
- original_elbo_weight=0.,
70
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
71
- l_simple_weight=1.,
72
- conditioning_key=None,
73
- parameterization="eps", # all assuming fixed variance schedules
74
- scheduler_config=None,
75
- use_positional_encodings=False,
76
- learn_logvar=False,
77
- logvar_init=0.,
78
- load_ema=True,
79
- ):
80
- super().__init__()
81
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
82
- self.parameterization = parameterization
83
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
84
- if not self.parameterization == "eps":
85
- NotImplementedError("omp not supported")
86
-
87
- self.cond_stage_model = None
88
- self.clip_denoised = clip_denoised
89
- self.log_every_t = log_every_t
90
- self.first_stage_key = first_stage_key
91
- self.image_size = image_size # try conv?
92
- self.channels = channels
93
- self.use_positional_encodings = use_positional_encodings
94
- self.model = DiffusionWrapper(unet_config, conditioning_key)
95
- count_params(self.model, verbose=True)
96
- self.use_ema = use_ema
97
-
98
- self.use_scheduler = scheduler_config is not None
99
- if self.use_scheduler:
100
- self.scheduler_config = scheduler_config
101
-
102
- self.v_posterior = v_posterior
103
- self.original_elbo_weight = original_elbo_weight
104
- self.l_simple_weight = l_simple_weight
105
-
106
- if monitor is not None:
107
- self.monitor = monitor
108
-
109
- if self.use_ema and load_ema:
110
- self.model_ema = LitEma(self.model)
111
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
112
-
113
- if ckpt_path is not None:
114
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
115
-
116
- # If initialing from EMA-only checkpoint, create EMA model after loading.
117
- if self.use_ema and not load_ema:
118
- self.model_ema = LitEma(self.model)
119
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
120
-
121
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
122
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
123
-
124
- self.loss_type = loss_type
125
-
126
- self.learn_logvar = learn_logvar
127
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
128
- if self.learn_logvar:
129
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
130
-
131
-
132
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
133
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
134
- if exists(given_betas):
135
- betas = given_betas
136
- else:
137
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
138
- cosine_s=cosine_s)
139
- alphas = 1. - betas
140
- alphas_cumprod = np.cumprod(alphas, axis=0)
141
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
142
-
143
- timesteps, = betas.shape
144
- self.num_timesteps = int(timesteps)
145
- self.linear_start = linear_start
146
- self.linear_end = linear_end
147
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
148
-
149
- to_torch = partial(torch.tensor, dtype=torch.float32)
150
-
151
- self.register_buffer('betas', to_torch(betas))
152
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
153
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
154
-
155
- # calculations for diffusion q(x_t | x_{t-1}) and others
156
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
157
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
158
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
159
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
160
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
161
-
162
- # calculations for posterior q(x_{t-1} | x_t, x_0)
163
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
164
- 1. - alphas_cumprod) + self.v_posterior * betas
165
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
166
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
167
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
168
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
169
- self.register_buffer('posterior_mean_coef1', to_torch(
170
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
171
- self.register_buffer('posterior_mean_coef2', to_torch(
172
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
173
-
174
- if self.parameterization == "eps":
175
- lvlb_weights = self.betas ** 2 / (
176
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
177
- elif self.parameterization == "x0":
178
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
179
- else:
180
- raise NotImplementedError("mu not supported")
181
- # TODO how to choose this term
182
- lvlb_weights[0] = lvlb_weights[1]
183
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
184
- assert not torch.isnan(self.lvlb_weights).all()
185
-
186
- @contextmanager
187
- def ema_scope(self, context=None):
188
- if self.use_ema:
189
- self.model_ema.store(self.model.parameters())
190
- self.model_ema.copy_to(self.model)
191
- if context is not None:
192
- print(f"{context}: Switched to EMA weights")
193
- try:
194
- yield None
195
- finally:
196
- if self.use_ema:
197
- self.model_ema.restore(self.model.parameters())
198
- if context is not None:
199
- print(f"{context}: Restored training weights")
200
-
201
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
202
- sd = torch.load(path, map_location="cpu")
203
- if "state_dict" in list(sd.keys()):
204
- sd = sd["state_dict"]
205
- keys = list(sd.keys())
206
-
207
- # Our model adds additional channels to the first layer to condition on an input image.
208
- # For the first layer, copy existing channel weights and initialize new channel weights to zero.
209
- input_keys = [
210
- "model.diffusion_model.input_blocks.0.0.weight",
211
- "model_ema.diffusion_modelinput_blocks00weight",
212
- ]
213
-
214
- branch_1_keys = [
215
- "model.diffusion_model.input_blocks_branch_1",
216
- "model.diffusion_model.output_blocks_branch_1",
217
- "model.diffusion_model.out_branch_1",
218
- "model_ema.diffusion_modelinput_blocks_branch_100weight",
219
- "model_ema.diffusion_modelout_branch_10weight",
220
- "model_ema.diffusion_modelout_branch_12weight",
221
-
222
-
223
-
224
- ]
225
- mask_block_keys = [
226
- "model.diffusion_model.mask_blocks"
227
- "model_ema.diffusion_modelmask_blocks00weight",
228
- ]
229
- ignore_keys += mask_block_keys
230
- self_sd = self.state_dict()
231
-
232
-
233
- for input_key in input_keys:
234
- if input_key not in sd or input_key not in self_sd:
235
- continue
236
-
237
- input_weight = self_sd[input_key]
238
-
239
- if input_weight.size() != sd[input_key].size():
240
- print(f"Manual init: {input_key}")
241
- input_weight.zero_()
242
- input_weight[:, :4, :, :].copy_(sd[input_key])
243
- ignore_keys.append(input_key)
244
-
245
-
246
- # for branch_1_key in branch_1_keys:
247
- # start_with_branch_1_keys = [k for k in self_sd if k.startswith(branch_1_key)]
248
- # main_keys = [k.replace("_branch_1", "") for k in start_with_branch_1_keys]
249
-
250
- # for start_with_branch_1_key, main_key in zip(start_with_branch_1_keys, main_keys):
251
- # if start_with_branch_1_key not in self_sd or main_key not in sd:
252
- # continue
253
-
254
- # branch_1_weight = self_sd[start_with_branch_1_key]
255
- # if branch_1_weight.size() != sd[main_key].size():
256
- # print(f"Manual init: {start_with_branch_1_key}")
257
- # branch_1_weight.zero_()
258
- # branch_1_weight[:, :4, :, :].copy_(sd[main_key])
259
- # ignore_keys.append(start_with_branch_1_key)
260
- # else:
261
- # branch_1_weight.zero_()
262
- # branch_1_weight.copy_(sd[main_key])
263
- # ignore_keys.append(start_with_branch_1_key)
264
-
265
- for k in keys:
266
- for ik in ignore_keys:
267
- if k.startswith(ik):
268
- print("Deleting key {} from state_dict.".format(k))
269
- del sd[k]
270
-
271
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
272
- sd, strict=False)
273
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
274
- if len(missing) > 0:
275
- print(f"Missing Keys: {missing}")
276
- if len(unexpected) > 0:
277
- print(f"Unexpected Keys: {unexpected}")
278
-
279
-
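
The deleted init_from_ckpt above widens the first input convolution: when the checkpoint weight has fewer input channels than the current model (extra channels were added for image conditioning), the new weight is zero-initialized, the pretrained channels are copied into the first four slots, and the key is added to ignore_keys so it is skipped during state-dict loading. A hedged sketch of that copy, assuming plain state-dict tensors of shape (out_ch, in_ch, k, k):

# Sketch of zero-initializing a widened input-conv weight and copying the
# pretrained channels, as done in init_from_ckpt above. Shapes are assumptions,
# and the tensors are plain state-dict weights (no autograd involved).
import torch

def widen_input_conv(new_weight, ckpt_weight):
    # new_weight:  (out_ch, in_ch_new, k, k) from the current model
    # ckpt_weight: (out_ch, in_ch_old, k, k) from the pretrained checkpoint
    in_ch_old = ckpt_weight.shape[1]
    new_weight.zero_()                            # extra channels start at zero
    new_weight[:, :in_ch_old].copy_(ckpt_weight)  # keep pretrained behaviour
    return new_weight
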
280
- def q_mean_variance(self, x_start, t):
281
- """
282
- Get the distribution q(x_t | x_0).
283
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
284
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
285
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
286
- """
287
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
288
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
289
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
290
- return mean, variance, log_variance
291
-
292
- def predict_start_from_noise(self, x_t, t, noise):
293
- return (
294
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
295
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
296
- )
297
-
298
- def q_posterior(self, x_start, x_t, t):
299
- posterior_mean = (
300
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
301
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
302
- )
303
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
304
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
305
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
306
-
307
- def p_mean_variance(self, x, t, clip_denoised: bool):
308
- model_out = self.model(x, t)
309
- if self.parameterization == "eps":
310
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
311
- elif self.parameterization == "x0":
312
- x_recon = model_out
313
- if clip_denoised:
314
- x_recon.clamp_(-1., 1.)
315
-
316
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
317
- return model_mean, posterior_variance, posterior_log_variance
318
-
319
- @torch.no_grad()
320
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
321
- b, *_, device = *x.shape, x.device
322
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
323
- noise = noise_like(x.shape, device, repeat_noise)
324
- # no noise when t == 0
325
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
326
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
327
-
328
- @torch.no_grad()
329
- def p_sample_loop(self, shape, return_intermediates=False):
330
- device = self.betas.device
331
- b = shape[0]
332
- img = torch.randn(shape, device=device)
333
- intermediates = [img]
334
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
335
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
336
- clip_denoised=self.clip_denoised)
337
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
338
- intermediates.append(img)
339
- if return_intermediates:
340
- return img, intermediates
341
- return img
342
-
343
- @torch.no_grad()
344
- def sample(self, batch_size=16, return_intermediates=False):
345
- image_size = self.image_size
346
- channels = self.channels
347
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
348
- return_intermediates=return_intermediates)
349
-
350
- def q_sample(self, x_start, t, noise=None):
351
- noise = default(noise, lambda: torch.randn_like(x_start))
352
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
353
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
354
-
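
The deleted q_sample above implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, gathering the per-timestep coefficients registered in register_schedule. A standalone sketch with explicit coefficients; the linear schedule values here are illustrative, not the exact ones from the config:

# Standalone sketch of the closed-form q_sample above (illustrative schedule).
import torch

T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x_start, t, noise=None):
    noise = torch.randn_like(x_start) if noise is None else noise
    a = alphas_cumprod[t].view(-1, 1, 1, 1)        # gather per-sample alpha_bar_t
    return a.sqrt() * x_start + (1.0 - a).sqrt() * noise

x_0 = torch.randn(2, 4, 32, 32)
t = torch.randint(0, T, (2,))
x_t = q_sample(x_0, t)
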
355
- def get_loss(self, pred, target, mean=True):
356
- if self.loss_type == 'l1':
357
- loss = (target - pred).abs()
358
- if mean:
359
- loss = loss.mean()
360
- elif self.loss_type == 'l2':
361
- if mean:
362
- loss = torch.nn.functional.mse_loss(target, pred)
363
- else:
364
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
365
- else:
366
- raise NotImplementedError("unknown loss type '{loss_type}'")
367
-
368
- return loss
369
-
370
- def p_losses(self, x_start, t, noise=None):
371
- noise = default(noise, lambda: torch.randn_like(x_start))
372
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
373
- model_out = self.model(x_noisy, t)
374
-
375
- loss_dict = {}
376
- if self.parameterization == "eps":
377
- target = noise
378
- elif self.parameterization == "x0":
379
- target = x_start
380
- else:
381
- raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
382
-
383
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
384
-
385
- log_prefix = 'train' if self.training else 'val'
386
-
387
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
388
- loss_simple = loss.mean() * self.l_simple_weight
389
-
390
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
391
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
392
-
393
- loss = loss_simple + self.original_elbo_weight * loss_vlb
394
-
395
- loss_dict.update({f'{log_prefix}/loss': loss})
396
-
397
- return loss, loss_dict
398
-
399
- def forward(self, x, *args, **kwargs):
400
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
401
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
402
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
403
- return self.p_losses(x, t, *args, **kwargs)
404
-
405
- def get_input(self, batch, k):
406
- return batch[k]
407
-
408
- def shared_step(self, batch):
409
- x = self.get_input(batch, self.first_stage_key)
410
- loss, loss_dict = self(x)
411
- return loss, loss_dict
412
-
413
- def training_step(self, batch, batch_idx):
414
- loss, loss_dict = self.shared_step(batch)
415
-
416
- self.log_dict(loss_dict, prog_bar=True,
417
- logger=True, on_step=True, on_epoch=True)
418
-
419
- self.log("global_step", self.global_step,
420
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
421
-
422
- if self.use_scheduler:
423
- lr = self.optimizers().param_groups[0]['lr']
424
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
425
-
426
- return loss
427
-
428
- @torch.no_grad()
429
- def validation_step(self, batch, batch_idx):
430
- _, loss_dict_no_ema = self.shared_step(batch)
431
- with self.ema_scope():
432
- _, loss_dict_ema = self.shared_step(batch)
433
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
434
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
435
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
436
-
437
- def on_train_batch_end(self, *args, **kwargs):
438
- if self.use_ema:
439
- self.model_ema(self.model)
440
-
441
- def _get_rows_from_list(self, samples):
442
- n_imgs_per_row = len(samples)
443
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
444
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
445
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
446
- return denoise_grid
447
-
448
- @torch.no_grad()
449
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
450
- log = dict()
451
- x = self.get_input(batch, self.first_stage_key)
452
- N = min(x.shape[0], N)
453
- n_row = min(x.shape[0], n_row)
454
- x = x.to(self.device)[:N]
455
- log["inputs"] = x
456
-
457
- # get diffusion row
458
- diffusion_row = list()
459
- x_start = x[:n_row]
460
-
461
- for t in range(self.num_timesteps):
462
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
463
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
464
- t = t.to(self.device).long()
465
- noise = torch.randn_like(x_start)
466
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
467
- diffusion_row.append(x_noisy)
468
-
469
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
470
-
471
- if sample:
472
- # get denoise row
473
- with self.ema_scope("Plotting"):
474
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
475
-
476
- log["samples"] = samples
477
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
478
-
479
- if return_keys:
480
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
481
- return log
482
- else:
483
- return {key: log[key] for key in return_keys}
484
- return log
485
-
486
- def configure_optimizers(self):
487
- lr = self.learning_rate
488
- params = list(self.model.parameters())
489
- if self.learn_logvar:
490
- params = params + [self.logvar]
491
- opt = torch.optim.AdamW(params, lr=lr)
492
- return opt
493
-
494
-
495
- class LatentDiffusion(DDPM):
496
- """main class"""
497
- def __init__(self,
498
- first_stage_config,
499
- cond_stage_config,
500
- num_timesteps_cond=None,
501
- cond_stage_key="image",
502
- cond_stage_trainable=False,
503
- concat_mode=True,
504
- cond_stage_forward=None,
505
- conditioning_key=None,
506
- scale_factor=1.0,
507
- scale_by_std=False,
508
- load_ema=True,
509
- first_stage_downsample=False,
510
- mask_loss_factor=1.0,
511
- *args, **kwargs):
512
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
513
- self.scale_by_std = scale_by_std
514
- assert self.num_timesteps_cond <= kwargs['timesteps']
515
- # for backwards compatibility after implementation of DiffusionWrapper
516
- if conditioning_key is None:
517
- conditioning_key = 'concat' if concat_mode else 'crossattn'
518
- if cond_stage_config == '__is_unconditional__':
519
- conditioning_key = None
520
- ckpt_path = kwargs.pop("ckpt_path", None)
521
- ignore_keys = kwargs.pop("ignore_keys", [])
522
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
523
- self.concat_mode = concat_mode
524
- self.cond_stage_trainable = cond_stage_trainable
525
- self.cond_stage_key = cond_stage_key
526
- try:
527
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
528
- except:
529
- self.num_downs = 0
530
- if not scale_by_std:
531
- self.scale_factor = scale_factor
532
- else:
533
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
534
- self.instantiate_first_stage(first_stage_config)
535
- self.first_stage_downsample = first_stage_downsample
536
- self.mask_loss_factor = mask_loss_factor
537
- self.instantiate_cond_stage(cond_stage_config)
538
- self.cond_stage_forward = cond_stage_forward
539
- self.clip_denoised = False
540
- self.bbox_tokenizer = None
541
-
542
- self.restarted_from_ckpt = False
543
- if ckpt_path is not None:
544
- self.init_from_ckpt(ckpt_path, ignore_keys)
545
- self.restarted_from_ckpt = True
546
-
547
- if self.use_ema and not load_ema:
548
- self.model_ema = LitEma(self.model)
549
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
550
-
551
- def make_cond_schedule(self, ):
552
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
553
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
554
- self.cond_ids[:self.num_timesteps_cond] = ids
555
-
556
- @rank_zero_only
557
- @torch.no_grad()
558
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
559
- # only for very first batch
560
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
561
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
562
- # set rescale weight to 1./std of encodings
563
- print("### USING STD-RESCALING ###")
564
- x = super().get_input(batch, self.first_stage_key)
565
- x = x.to(self.device)
566
- encoder_posterior = self.encode_first_stage(x)
567
- z = self.get_first_stage_encoding(encoder_posterior).detach()
568
- del self.scale_factor
569
- self.register_buffer('scale_factor', 1. / z.flatten().std())
570
- print(f"setting self.scale_factor to {self.scale_factor}")
571
- print("### USING STD-RESCALING ###")
572
-
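
The deleted on_train_batch_start above sets scale_factor once, on the very first training batch, to the reciprocal of the standard deviation of the encoded latents, so that the diffusion process sees roughly unit-variance inputs. A minimal sketch of that statistic, where encode stands in for the frozen first-stage encoder:

# Minimal sketch of the std-rescaling above: scale_factor = 1 / std(z).
# `encode` is a placeholder for the frozen first-stage encoder.
import torch

def compute_scale_factor(encode, x):
    with torch.no_grad():
        z = encode(x)                   # latent encodings of the first batch
    return 1.0 / z.flatten().std()      # gives roughly unit-variance latents
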
573
- def register_schedule(self,
574
- given_betas=None, beta_schedule="linear", timesteps=1000,
575
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
576
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
577
-
578
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
579
- if self.shorten_cond_schedule:
580
- self.make_cond_schedule()
581
-
582
- def instantiate_first_stage(self, config):
583
- model = instantiate_from_config(config)
584
- self.first_stage_model = model.eval()
585
- self.first_stage_model.train = disabled_train
586
- for param in self.first_stage_model.parameters():
587
- param.requires_grad = False
588
-
589
- def instantiate_cond_stage(self, config):
590
- if not self.cond_stage_trainable:
591
- if config == "__is_first_stage__":
592
- print("Using first stage also as cond stage.")
593
- self.cond_stage_model = self.first_stage_model
594
- elif config == "__is_unconditional__":
595
- print(f"Training {self.__class__.__name__} as an unconditional model.")
596
- self.cond_stage_model = None
597
- # self.be_unconditional = True
598
- else:
599
- model = instantiate_from_config(config)
600
- self.cond_stage_model = model.eval()
601
- self.cond_stage_model.train = disabled_train
602
- for param in self.cond_stage_model.parameters():
603
- param.requires_grad = False
604
- else:
605
- assert config != '__is_first_stage__'
606
- assert config != '__is_unconditional__'
607
- model = instantiate_from_config(config)
608
- self.cond_stage_model = model
609
-
610
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
611
- denoise_row = []
612
- for zd in tqdm(samples, desc=desc):
613
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
614
- force_not_quantize=force_no_decoder_quantization))
615
- n_imgs_per_row = len(denoise_row)
616
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
617
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
618
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
619
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
620
- return denoise_grid
621
-
622
- def get_first_stage_encoding(self, encoder_posterior):
623
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
624
- z = encoder_posterior.sample()
625
- elif isinstance(encoder_posterior, torch.Tensor):
626
- z = encoder_posterior
627
- else:
628
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
629
- return self.scale_factor * z
630
-
631
- def get_learned_conditioning(self, c):
632
- if self.cond_stage_forward is None:
633
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
634
- c = self.cond_stage_model.encode(c)
635
- if isinstance(c, DiagonalGaussianDistribution):
636
- c = c.mode()
637
- else:
638
- c = self.cond_stage_model(c)
639
- else:
640
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
641
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
642
- return c
643
-
644
- def get_vision_conditioning(self, c):
645
- c = self.cond_stage_model.vision_forward(c)
646
- return c
647
-
648
- def meshgrid(self, h, w):
649
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
650
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
651
-
652
- arr = torch.cat([y, x], dim=-1)
653
- return arr
654
-
655
- def delta_border(self, h, w):
656
- """
657
- :param h: height
658
- :param w: width
659
- :return: normalized distance to image border,
660
- with min distance = 0 at the border and max distance = 0.5 at the image center
661
- """
662
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
663
- arr = self.meshgrid(h, w) / lower_right_corner
664
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
665
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
666
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
667
- return edge_dist
668
-
669
- def get_weighting(self, h, w, Ly, Lx, device):
670
- weighting = self.delta_border(h, w)
671
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
672
- self.split_input_params["clip_max_weight"], )
673
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
674
-
675
- if self.split_input_params["tie_braker"]:
676
- L_weighting = self.delta_border(Ly, Lx)
677
- L_weighting = torch.clip(L_weighting,
678
- self.split_input_params["clip_min_tie_weight"],
679
- self.split_input_params["clip_max_tie_weight"])
680
-
681
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
682
- weighting = weighting * L_weighting
683
- return weighting
684
-
685
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
686
- """
687
- :param x: img of size (bs, c, h, w)
688
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
689
- """
690
- bs, nc, h, w = x.shape
691
-
692
- # number of crops in image
693
- Ly = (h - kernel_size[0]) // stride[0] + 1
694
- Lx = (w - kernel_size[1]) // stride[1] + 1
695
-
696
- if uf == 1 and df == 1:
697
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
698
- unfold = torch.nn.Unfold(**fold_params)
699
-
700
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
701
-
702
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
703
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
704
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
705
-
706
- elif uf > 1 and df == 1:
707
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
708
- unfold = torch.nn.Unfold(**fold_params)
709
-
710
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
711
- dilation=1, padding=0,
712
- stride=(stride[0] * uf, stride[1] * uf))
713
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
714
-
715
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
716
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
717
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
718
-
719
- elif df > 1 and uf == 1:
720
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
721
- unfold = torch.nn.Unfold(**fold_params)
722
-
723
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
724
- dilation=1, padding=0,
725
- stride=(stride[0] // df, stride[1] // df))
726
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
727
-
728
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
729
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
730
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
731
-
732
- else:
733
- raise NotImplementedError
734
-
735
- return fold, unfold, normalization, weighting
736
-
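
The deleted get_fold_unfold above slides a kernel over the latent, lets the model process each crop separately, and stitches the results back with torch.nn.Fold; folding the crop weighting through the same windows yields the per-pixel normalization map that corrects for overlap. The deleted code weights each crop by its distance to the border, whereas the sketch below uses uniform weights just to show the overlap normalization; sizes are arbitrary:

# Sketch of the unfold/fold round trip and overlap normalization used above.
import torch

x = torch.randn(1, 4, 64, 64)
ks, stride = (32, 32), (16, 16)

unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=x.shape[2:], kernel_size=ks, stride=stride)

patches = unfold(x)                         # (1, 4*32*32, L) flattened crops
ones = torch.ones_like(patches)
normalization = fold(ones)                  # per-pixel overlap count
x_rec = fold(patches) / normalization       # reconstructs x up to numerics

print(torch.allclose(x, x_rec, atol=1e-5))  # True
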
737
- @torch.no_grad()
738
- def get_input(self, batch, keys, return_first_stage_outputs=False, force_c_encode=False,
739
- cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
740
- x_0 = super().get_input(batch, keys[0])
741
- x_1 = super().get_input(batch, keys[1])
742
- if bs is not None:
743
- x_0 = x_0[:bs]
744
- x_1 = x_1[:bs]
745
- x_0 = x_0.to(self.device)
746
- x_1 = x_1.to(self.device)
747
-
748
- encoder_posterior = self.encode_first_stage(x_0)
749
- z_0 = self.get_first_stage_encoding(encoder_posterior).detach()
750
- if self.first_stage_downsample:
751
- z_1 = F.interpolate(x_1, scale_factor=1/8, mode='bilinear', align_corners=False)
752
- z_1 = torch.where(z_1 > 0.5, 1, -1).float() # Thresholding step
753
- else:
754
- z_1 = self.get_first_stage_encoding(self.encode_first_stage(x_1)).detach()
755
-
756
- cond_key = cond_key or self.cond_stage_key
757
- xc = super().get_input(batch, cond_key)
758
- if bs is not None:
759
- xc["c_crossattn"] = xc["c_crossattn"][:bs]
760
- xc["c_concat"] = xc["c_concat"][:bs]
761
- cond = {}
762
-
763
- # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
764
- random = torch.rand(x_0.size(0), device=x_0.device)
765
- prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
766
- input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
767
-
768
- null_prompt = self.get_learned_conditioning([""])
769
- cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
770
- cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
771
-
772
- out = [z_0, z_1, cond]
773
- if return_first_stage_outputs:
774
- x_0_rec = self.decode_first_stage(z_0)
775
-
776
- if self.first_stage_downsample:
777
- x_1_rec = F.interpolate(z_1, scale_factor=8, mode='bilinear', align_corners=False)
778
- x_1_rec = torch.where(x_1_rec > 0, 1, -1) # Thresholding step
779
- else:
780
- x_1_rec = self.decode_first_stage(z_1)
781
- out.extend([x_0, x_0_rec, x_1, x_1_rec])
782
- if return_original_cond:
783
- out.append(xc)
784
-
785
- return out
786
-
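
The conditioning dropout in the deleted get_input above draws one uniform value per example and uses three adjacent intervals of width uncond, so with the default of 0.05 the text-only, image-only, and joint dropout cases each cover about 5% of the batch. A sketch of the interval logic with a hypothetical batch:

# Sketch of the three-way conditioning dropout above (uncond = 0.05 by default):
#   r < 2*uncond            -> replace the text prompt with the null prompt
#   uncond <= r < 3*uncond  -> zero out the image conditioning
# The overlap [uncond, 2*uncond) drops both, giving ~5% for each of the three cases.
import torch

uncond = 0.05
r = torch.rand(8)                                   # one draw per batch element

prompt_mask = (r < 2 * uncond)                                   # True -> null prompt
input_mask = 1.0 - ((r >= uncond) & (r < 3 * uncond)).float()    # 0 -> drop image cond
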
787
- @torch.no_grad()
788
- def forward_mask_decoder(self, input_image, output_image, c, time_step):
789
- time_step = torch.tensor(time_step).unsqueeze(0)
790
- t = time_step.to(self.device).long()
791
- # time_step to torch tensor
792
-
793
-
794
- noise_0 = default(None, lambda: torch.randn_like(output_image))
795
- output_image_noise = self.q_sample(x_start=output_image, t=t, noise=noise_0)
796
-
797
-
798
- xc = torch.cat([output_image_noise] + [input_image], dim=1) # Convert input_image to a list
799
- cc = torch.cat([c], 1)
800
-
801
- mask = self.model.diffusion_model.decode_mask(xc, t, context=cc)
802
-
803
-
804
- return mask
805
-
806
- @torch.no_grad()
807
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
808
- if predict_cids:
809
- if z.dim() == 4:
810
- z = torch.argmax(z.exp(), dim=1).long()
811
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
812
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
813
-
814
- z = 1. / self.scale_factor * z
815
-
816
- if hasattr(self, "split_input_params"):
817
- if self.split_input_params["patch_distributed_vq"]:
818
- ks = self.split_input_params["ks"] # eg. (128, 128)
819
- stride = self.split_input_params["stride"] # eg. (64, 64)
820
- uf = self.split_input_params["vqf"]
821
- bs, nc, h, w = z.shape
822
- if ks[0] > h or ks[1] > w:
823
- ks = (min(ks[0], h), min(ks[1], w))
824
- print("reducing Kernel")
825
-
826
- if stride[0] > h or stride[1] > w:
827
- stride = (min(stride[0], h), min(stride[1], w))
828
- print("reducing stride")
829
-
830
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
831
-
832
- z = unfold(z) # (bn, nc * prod(**ks), L)
833
- # 1. Reshape to img shape
834
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
835
-
836
- # 2. apply model loop over last dim
837
- if isinstance(self.first_stage_model, VQModelInterface):
838
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
839
- force_not_quantize=predict_cids or force_not_quantize)
840
- for i in range(z.shape[-1])]
841
- else:
842
-
843
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
844
- for i in range(z.shape[-1])]
845
-
846
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
847
- o = o * weighting
848
- # Reverse 1. reshape to img shape
849
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
850
- # stitch crops together
851
- decoded = fold(o)
852
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
853
- return decoded
854
- else:
855
- if isinstance(self.first_stage_model, VQModelInterface):
856
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
857
- else:
858
- return self.first_stage_model.decode(z)
859
- # elif self.first_stage_downsample:
860
-         # # For z.shape = [b, h//2, w//2], simply upsample to [b, h, w] instead of going through self.first_stage_model
861
- # z = F.interpolate(z, scale_factor=8, mode='bilinear', align_corners=False)
862
- # z = torch.where(z > 0.5, 1, 0) # Thresholding step
863
- # return z
864
-
865
- else:
866
- if isinstance(self.first_stage_model, VQModelInterface):
867
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
868
- else:
869
- return self.first_stage_model.decode(z)
870
-
871
- # same as above but without decorator
872
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
873
- if predict_cids:
874
- if z.dim() == 4:
875
- z = torch.argmax(z.exp(), dim=1).long()
876
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
877
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
878
-
879
- z = 1. / self.scale_factor * z
880
-
881
- if hasattr(self, "split_input_params"):
882
- if self.split_input_params["patch_distributed_vq"]:
883
- ks = self.split_input_params["ks"] # eg. (128, 128)
884
- stride = self.split_input_params["stride"] # eg. (64, 64)
885
- uf = self.split_input_params["vqf"]
886
- bs, nc, h, w = z.shape
887
- if ks[0] > h or ks[1] > w:
888
- ks = (min(ks[0], h), min(ks[1], w))
889
- print("reducing Kernel")
890
-
891
- if stride[0] > h or stride[1] > w:
892
- stride = (min(stride[0], h), min(stride[1], w))
893
- print("reducing stride")
894
-
895
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
896
-
897
- z = unfold(z) # (bn, nc * prod(**ks), L)
898
- # 1. Reshape to img shape
899
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
900
-
901
- # 2. apply model loop over last dim
902
- if isinstance(self.first_stage_model, VQModelInterface):
903
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
904
- force_not_quantize=predict_cids or force_not_quantize)
905
- for i in range(z.shape[-1])]
906
- else:
907
-
908
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
909
- for i in range(z.shape[-1])]
910
-
911
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
912
- o = o * weighting
913
- # Reverse 1. reshape to img shape
914
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
915
- # stitch crops together
916
- decoded = fold(o)
917
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
918
- return decoded
919
- else:
920
- if isinstance(self.first_stage_model, VQModelInterface):
921
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
922
- else:
923
- return self.first_stage_model.decode(z)
924
-
925
- else:
926
- if isinstance(self.first_stage_model, VQModelInterface):
927
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
928
- else:
929
- return self.first_stage_model.decode(z)
930
-
931
- @torch.no_grad()
932
- def encode_first_stage(self, x):
933
- if hasattr(self, "split_input_params"):
934
- if self.split_input_params["patch_distributed_vq"]:
935
- ks = self.split_input_params["ks"] # eg. (128, 128)
936
- stride = self.split_input_params["stride"] # eg. (64, 64)
937
- df = self.split_input_params["vqf"]
938
- self.split_input_params['original_image_size'] = x.shape[-2:]
939
- bs, nc, h, w = x.shape
940
- if ks[0] > h or ks[1] > w:
941
- ks = (min(ks[0], h), min(ks[1], w))
942
- print("reducing Kernel")
943
-
944
- if stride[0] > h or stride[1] > w:
945
- stride = (min(stride[0], h), min(stride[1], w))
946
- print("reducing stride")
947
-
948
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
949
- z = unfold(x) # (bn, nc * prod(**ks), L)
950
- # Reshape to img shape
951
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
952
-
953
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
954
- for i in range(z.shape[-1])]
955
-
956
- o = torch.stack(output_list, axis=-1)
957
- o = o * weighting
958
-
959
- # Reverse reshape to img shape
960
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
961
- # stitch crops together
962
- decoded = fold(o)
963
- decoded = decoded / normalization
964
- return decoded
965
-
966
- else:
967
- return self.first_stage_model.encode(x)
968
- # elif first_stage_downsample:
969
-         # # For x.shape = [b, h, w], simply downsample to [b, h//2, w//2] instead of going through self.first_stage_model
970
- # x = F.interpolate(x, scale_factor=1/8, mode='bilinear', align_corners=False)
971
- # # x = torch.where(x < 0.5, torch.zeros_like(x), torch.ones_like(x)) # Thresholding step
972
- # x = torch.where(x > 0.5, 1, -1) # Thresholding step
973
- # return x
974
- else:
975
- return self.first_stage_model.encode(x)
976
-
977
- def shared_step(self, batch, **kwargs):
978
- x_0, x_1, c = self.get_input(batch, self.first_stage_key)
979
- loss = self(x_0, x_1, c)
980
- return loss
981
-
982
- def forward(self, x_0, x_1, c, *args, **kwargs):
983
- t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=self.device).long()
984
- if self.model.conditioning_key is not None:
985
- assert c is not None
986
- # in pix2pix, cond_stage_trainable and short_cond_schedule are false
987
- if self.cond_stage_trainable:
988
- c = self.get_learned_conditioning(c)
989
- if self.shorten_cond_schedule: # TODO: drop this option
990
- tc = self.cond_ids[t].to(self.device)
991
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
992
- return self.p_losses(x_0, x_1, c, t, *args, **kwargs)
993
-
994
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
995
- def rescale_bbox(bbox):
996
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
997
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
998
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
999
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
1000
- return x0, y0, w, h
1001
-
1002
- return [rescale_bbox(b) for b in bboxes]
1003
-
1004
- def apply_model(self, x_noisy_0, x_noisy_1, t, cond, return_ids=False):
1005
- if isinstance(cond, dict):
1006
-             # hybrid case, cond is expected to be a dict
1007
- pass
1008
- else:
1009
- if not isinstance(cond, list):
1010
- cond = [cond]
1011
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
1012
- cond = {key: cond}
1013
-
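-         # NOTE: the patched ("split_input_params") path below is inherited from the original latent-diffusion
-         # code and still produces a single reconstruction; it has not been adapted to the two-branch
-         # (image + mask) model handled in the else-branch further down.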
1014
- if hasattr(self, "split_input_params"):
1015
- assert len(cond) == 1 # todo can only deal with one conditioning atm
1016
- assert not return_ids
1017
- ks = self.split_input_params["ks"] # eg. (128, 128)
1018
- stride = self.split_input_params["stride"] # eg. (64, 64)
1019
-
1020
-             h, w = x_noisy_0.shape[-2:]
1021
-
1022
-             fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy_0, ks, stride)
1023
-
1024
-             z = unfold(x_noisy_0)  # (bn, nc * prod(**ks), L)
1025
- # Reshape to img shape
1026
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1027
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
1028
-
1029
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
1030
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
1031
- c_key = next(iter(cond.keys())) # get key
1032
- c = next(iter(cond.values())) # get value
1033
- assert (len(c) == 1) # todo extend to list with more than one elem
1034
- c = c[0] # get element
1035
-
1036
- c = unfold(c)
1037
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
1038
-
1039
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
1040
-
1041
- elif self.cond_stage_key == 'coordinates_bbox':
1042
-                 assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
1043
-
1044
- # assuming padding of unfold is always 0 and its dilation is always 1
1045
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
1046
- full_img_h, full_img_w = self.split_input_params['original_image_size']
1047
- # as we are operating on latents, we need the factor from the original image size to the
1048
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
1049
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
1050
- rescale_latent = 2 ** (num_downs)
1051
-
1052
-                 # get top-left positions of the patches in the format expected by the bbox tokenizer; therefore we
1053
-                 # need to rescale the top-left patch coordinates to lie in (0, 1)
1054
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
1055
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
1056
- for patch_nr in range(z.shape[-1])]
1057
-
1058
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
1059
- patch_limits = [(x_tl, y_tl,
1060
- rescale_latent * ks[0] / full_img_w,
1061
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
1062
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
1063
-
1064
- # tokenize crop coordinates for the bounding boxes of the respective patches
1065
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
1066
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
1067
- print(patch_limits_tknzd[0].shape)
1068
- # cut tknzd crop position from conditioning
1069
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
1070
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
1071
- print(cut_cond.shape)
1072
-
1073
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
1074
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
1075
- print(adapted_cond.shape)
1076
- adapted_cond = self.get_learned_conditioning(adapted_cond)
1077
- print(adapted_cond.shape)
1078
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
1079
- print(adapted_cond.shape)
1080
-
1081
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
1082
-
1083
- else:
1084
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
1085
-
1086
- # apply model by loop over crops
1087
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1088
- assert not isinstance(output_list[0],
1089
-                                   tuple)  # TODO: cannot handle multiple model outputs; check this never happens
1090
-
1091
- o = torch.stack(output_list, axis=-1)
1092
- o = o * weighting
1093
- # Reverse reshape to img shape
1094
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1095
- # stitch crops together
1096
- x_recon = fold(o) / normalization
1097
-
1098
- else:
1099
- x_recon_0, x_recon_1 = self.model(x_noisy_0, t, x_noisy_1, **cond)
1100
-
1101
- # predict_image = self.decode_first_stage(x_start_0[:1])
1102
- # predict_image = torch.clamp((predict_image + 1.0) / 2.0, min=0.0, max=1.0)
1103
- # from PIL import Image
1104
- # predict_image = 255.0 * rearrange(predict_image, "1 c h w -> h w c")
1105
- # predict_image = Image.fromarray(predict_image.type(torch.uint8).cpu().numpy())
1106
- # predict_image.save("predict_image.png")
1107
-
1108
-
1109
-
1110
- if isinstance(x_recon_0, tuple) and not return_ids:
1111
- return x_recon_0[0], x_recon_1[0]
1112
- else:
1113
- return x_recon_0, x_recon_1
1114
-
1115
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1116
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1117
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1118
-
1119
- def _prior_bpd(self, x_start):
1120
- """
1121
- Get the prior KL term for the variational lower-bound, measured in
1122
- bits-per-dim.
1123
- This term can't be optimized, as it only depends on the encoder.
1124
- :param x_start: the [N x C x ...] tensor of inputs.
1125
- :return: a batch of [N] KL values (in bits), one per batch element.
1126
- """
1127
- batch_size = x_start.shape[0]
1128
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1129
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1130
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1131
- return mean_flat(kl_prior) / np.log(2.0)
1132
-
1133
- def p_losses(self, x_start_0, x_start_1, cond, t, noise=None):
1134
- noise_0 = default(noise, lambda: torch.randn_like(x_start_0))
1135
- x_noisy_0 = self.q_sample(x_start=x_start_0, t=t, noise=noise_0)
1136
- if self.first_stage_downsample:
1137
- x_noisy_1 = None
1138
- else:
1139
- noise_1 = default(noise, lambda: torch.randn_like(x_start_1))
1140
- x_noisy_1 = self.q_sample(x_start=x_start_1, t=t, noise=noise_1)
1141
- model_output_0, model_output_1 = self.apply_model(x_noisy_0, x_noisy_1, t, cond)
1142
-
1143
- loss_dict = {}
1144
- prefix = 'train' if self.training else 'val'
1145
-
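-         # Target selection for the two branches: when the mask branch bypasses the VAE (first_stage_downsample),
-         # the image branch predicts the noise and the mask branch regresses the downsampled mask x_start_1 directly;
-         # otherwise both branches follow self.parameterization ("x0" or "eps").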
1146
- if self.first_stage_downsample:
1147
- target_0 = noise_0
1148
- target_1 = x_start_1
1149
- elif self.parameterization == "x0":
1150
- target_0 = x_start_0
1151
- target_1 = x_start_1
1152
- elif self.parameterization == "eps":
1153
- target_0 = noise_0
1154
- target_1 = noise_1
1155
- else:
1156
- raise NotImplementedError()
1157
-
1158
- loss_simple_0 = self.get_loss(model_output_0, target_0, mean=False).mean([1, 2, 3])
1159
- loss_simple_1 = self.get_loss(model_output_1, target_1, mean=False).mean([1, 2, 3])
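-         # Combine the two branches: image loss plus the mask loss scaled by mask_loss_factor.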
1160
- loss_simple = loss_simple_0 + loss_simple_1 * self.mask_loss_factor
1161
-
1162
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1163
- loss_dict.update({f'{prefix}/loss_simple_0': loss_simple_0.mean()})
1164
- loss_dict.update({f'{prefix}/loss_simple_1': loss_simple_1.mean()})
1165
-
1166
- # logvar_t = self.logvar[t].to(self.device)
1167
-         # make sure self.logvar is on the same device as the model
1168
- self.logvar = self.logvar.to(self.device)
1169
- logvar_t = self.logvar[t]
1170
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
1171
-
1172
- if self.learn_logvar:
1173
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1174
- loss_dict.update({'logvar': self.logvar.data.mean()})
1175
-
1176
- loss = self.l_simple_weight * loss.mean()
1177
-
1178
- loss_vlb_0 = self.get_loss(model_output_0, target_0, mean=False).mean(dim=(1, 2, 3))
1179
- loss_vlb_1 = self.get_loss(model_output_1, target_1, mean=False).mean(dim=(1, 2, 3))
1180
- loss_vlb = loss_vlb_0 + loss_vlb_1 * self.mask_loss_factor
1181
-
1182
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1183
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1184
- loss += (self.original_elbo_weight * loss_vlb)
1185
- loss_dict.update({f'{prefix}/loss': loss})
1186
-
1187
- return loss, loss_dict
1188
-
1189
- def p_mean_variance(self, x_0, x_1, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1190
- return_x0=False, score_corrector=None, corrector_kwargs=None):
1191
- t_in = t
1192
- model_out_0, model_out_1 = self.apply_model(x_0, x_1, t_in, c, return_ids=return_codebook_ids)
1193
-
1194
- if score_corrector is not None:
1195
- assert self.parameterization == "eps"
1196
-             model_out_0 = score_corrector.modify_score(self, model_out_0, x_0, t_in, c, **corrector_kwargs)
1197
-
1198
- if return_codebook_ids:
1199
- model_out, logits = model_out
1200
-
1201
- if self.parameterization == "eps":
1202
- x_recon_0 = self.predict_start_from_noise(x_0, t=t, noise=model_out_0)
1203
- x_recon_1 = self.predict_start_from_noise(x_1, t=t, noise=model_out_1)
1204
- elif self.parameterization == "x0":
1205
-             x_recon_0, x_recon_1 = model_out_0, model_out_1
1206
- else:
1207
- raise NotImplementedError()
1208
- if clip_denoised:
1209
-             x_recon_0.clamp_(-1., 1.)
1210
- if quantize_denoised:
1211
-             x_recon_0, _, [_, _, indices] = self.first_stage_model.quantize(x_recon_0)
1212
-
1213
- model_mean_0, posterior_variance_0, posterior_log_variance_0 = self.q_posterior(x_start=x_recon_0, x_t=x_0, t=t)
1214
- model_mean_1, posterior_variance_1, posterior_log_variance_1 = self.q_posterior(x_start=x_recon_1, x_t=x_1, t=t)
1215
- if return_codebook_ids:
1216
- return model_mean, posterior_variance, posterior_log_variance, logits
1217
- elif return_x0:
1218
- return model_mean, posterior_variance, posterior_log_variance, x_recon
1219
- else:
1220
- return model_mean_0, posterior_variance_0, posterior_log_variance_0, model_mean_1, posterior_variance_1, posterior_log_variance_1
1221
-
1222
- @torch.no_grad()
1223
- def p_sample(self, x_0, x_1, c, t, clip_denoised=False, repeat_noise=False,
1224
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1225
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
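-         # One reverse-diffusion step that jointly denoises the image latent x_0 and the mask latent x_1.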
1226
- b, *_, device = *x_0.shape, x_0.device
1227
- outputs = self.p_mean_variance(x_0=x_0, x_1=x_1, c=c, t=t, clip_denoised=clip_denoised,
1228
- return_codebook_ids=return_codebook_ids,
1229
- quantize_denoised=quantize_denoised,
1230
- return_x0=return_x0,
1231
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1232
-
1233
- if return_codebook_ids:
1234
- raise DeprecationWarning("Support dropped.")
1235
- model_mean, _, model_log_variance, logits = outputs
1236
- elif return_x0:
1237
- model_mean, _, model_log_variance, x0 = outputs
1238
- else:
1239
- model_mean_0, _, model_log_variance_0, model_mean_1, _, model_log_variance_1 = outputs
1240
-
1241
- noise = noise_like(x_0.shape, device, repeat_noise) * temperature
1242
- if noise_dropout > 0.:
1243
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1244
- # no noise when t == 0
1245
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_0.shape) - 1)))
1246
-
1247
- if return_codebook_ids:
1248
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1249
- if return_x0:
1250
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1251
- else:
1252
- return model_mean_0 + nonzero_mask * (0.5 * model_log_variance_0).exp() * noise, \
1253
- model_mean_1 + nonzero_mask * (0.5 * model_log_variance_1).exp() * noise
1254
-
1255
- @torch.no_grad()
1256
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1257
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1258
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1259
- log_every_t=None):
1260
- if not log_every_t:
1261
- log_every_t = self.log_every_t
1262
- timesteps = self.num_timesteps
1263
- if batch_size is not None:
1264
- b = batch_size if batch_size is not None else shape[0]
1265
- shape = [batch_size] + list(shape)
1266
- else:
1267
- b = batch_size = shape[0]
1268
- if x_T is None:
1269
- img = torch.randn(shape, device=self.device)
1270
- else:
1271
- img = x_T
1272
- intermediates = []
1273
- if cond is not None:
1274
- if isinstance(cond, dict):
1275
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1276
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1277
- else:
1278
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1279
-
1280
- if start_T is not None:
1281
- timesteps = min(timesteps, start_T)
1282
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1283
- total=timesteps) if verbose else reversed(
1284
- range(0, timesteps))
1285
- if type(temperature) == float:
1286
- temperature = [temperature] * timesteps
1287
-
1288
- for i in iterator:
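-         # Add noise to output_image up to timestep t, concatenate it with input_image along the channel
-         # dimension, and run only the mask-decoding head (decode_mask) of the UNet to predict the object mask.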
1289
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1290
- if self.shorten_cond_schedule:
1291
- assert self.model.conditioning_key != 'hybrid'
1292
- tc = self.cond_ids[ts].to(cond.device)
1293
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1294
-
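-             # NOTE: this call was not updated for the two-branch p_sample signature (x_0, x_1, c, t) and would
-             # fail if progressive_denoising were used with this model.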
1295
- img, x0_partial = self.p_sample(img, cond, ts,
1296
- clip_denoised=self.clip_denoised,
1297
- quantize_denoised=quantize_denoised, return_x0=True,
1298
-         xc = torch.cat([output_image_noise] + [input_image], dim=1)  # wrap input_image in a list so it concatenates along the channel dim
1299
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1300
- if mask is not None:
1301
- assert x0 is not None
1302
- img_orig = self.q_sample(x0, ts)
1303
- img = img_orig * mask + (1. - mask) * img
1304
-
1305
- if i % log_every_t == 0 or i == timesteps - 1:
1306
- intermediates.append(x0_partial)
1307
- if callback: callback(i)
1308
- if img_callback: img_callback(img, i)
1309
- return img, intermediates
1310
-
1311
- @torch.no_grad()
1312
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1313
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1314
- mask=None, x0=None, img_callback=None, start_T=None,
1315
- log_every_t=None):
1316
-
1317
- if not log_every_t:
1318
- log_every_t = self.log_every_t
1319
- device = self.betas.device
1320
- b = shape[0]
1321
-
1322
- if x_T is None:
1323
- img_0 = torch.randn(shape, device=device)
1324
- img_1 = torch.randn(shape, device=device)
1325
- else:
1326
-             img_0, img_1 = x_T, torch.randn(shape, device=device)  # x_T seeds the image latent; the mask latent still starts from noise
1327
-
1328
- intermediates = [img_0]
1329
- if timesteps is None:
1330
- timesteps = self.num_timesteps
1331
-
1332
- if start_T is not None:
1333
- timesteps = min(timesteps, start_T)
1334
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1335
- range(0, timesteps))
1336
-
1337
- if mask is not None:
1338
- assert x0 is not None
1339
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1340
-
1341
- for i in iterator:
1342
- ts = torch.full((b,), i, device=device, dtype=torch.long)
1343
- if self.shorten_cond_schedule:
1344
- assert self.model.conditioning_key != 'hybrid'
1345
- tc = self.cond_ids[ts].to(cond.device)
1346
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1347
-
1348
- img_0, img_1 = self.p_sample(img_0, img_1, cond, ts,
1349
- clip_denoised=self.clip_denoised,
1350
- quantize_denoised=quantize_denoised)
1351
-
1352
- if mask is not None:
1353
- img_orig = self.q_sample(x0, ts)
1354
-                 img_0 = img_orig * mask + (1. - mask) * img_0
1355
-
1356
- if i % log_every_t == 0 or i == timesteps - 1:
1357
- intermediates.append(img_0)
1358
- if callback: callback(i)
1359
-             if img_callback: img_callback(img_0, i)
1360
-
1361
- if return_intermediates:
1362
- return img_0, intermediates
1363
- return img_0
1364
-
1365
- @torch.no_grad()
1366
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1367
- verbose=True, timesteps=None, quantize_denoised=False,
1368
- mask=None, x0=None, shape=None,**kwargs):
1369
- if shape is None:
1370
- shape = (batch_size, self.channels, self.image_size, self.image_size)
1371
- if cond is not None:
1372
- if isinstance(cond, dict):
1373
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1374
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1375
- else:
1376
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1377
- return self.p_sample_loop(cond,
1378
- shape,
1379
- return_intermediates=return_intermediates, x_T=x_T,
1380
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1381
- mask=mask, x0=x0)
1382
-
1383
- @torch.no_grad()
1384
-     def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1385
-
1386
- if ddim:
1387
- ddim_sampler = DDIMSampler(self)
1388
- shape = (self.channels, self.image_size, self.image_size)
1389
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1390
- shape,cond,verbose=False,**kwargs)
1391
-
1392
- else:
1393
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1394
- return_intermediates=True,**kwargs)
1395
-
1396
- return samples, intermediates
1397
-
1398
-
1399
- @torch.no_grad()
1400
- def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1401
- quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1402
- plot_diffusion_rows=False, **kwargs):
1403
-
1404
- use_ddim = False
1405
-
1406
- log = dict()
1407
- # z_0, z_1, c, x_0, x_0_rec, x_1, x_1_rec, xc
1408
- # z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1409
- z_0, z_1, c, x_0, x_0_rec, x_1, x_1_rec, xc = self.get_input(batch, self.first_stage_key,
1410
- return_first_stage_outputs=True,
1411
- force_c_encode=True,
1412
- return_original_cond=True,
1413
- bs=N, uncond=0)
1414
- N = min(x_0.shape[0], N)
1415
- n_row = min(x_0.shape[0], n_row)
1416
- log["inputs"] = x_0
1417
- log["reals"] = xc["c_concat"]
1418
- log["reconstruction"] = x_0_rec
1419
- if self.model.conditioning_key is not None:
1420
- if hasattr(self.cond_stage_model, "decode"):
1421
- xc = self.cond_stage_model.decode(c)
1422
- log["conditioning"] = xc
1423
- elif self.cond_stage_key in ["caption"]:
1424
- xc = log_txt_as_img((x_0.shape[2], x_0.shape[3]), batch["caption"])
1425
- log["conditioning"] = xc
1426
- elif self.cond_stage_key == 'class_label':
1427
- xc = log_txt_as_img((x_0.shape[2], x_0.shape[3]), batch["human_label"])
1428
- log['conditioning'] = xc
1429
- elif isimage(xc):
1430
- log["conditioning"] = xc
1431
- if ismap(xc):
1432
- log["original_conditioning"] = self.to_rgb(xc)
1433
-
1434
- if plot_diffusion_rows:
1435
- # get diffusion row
1436
- diffusion_row = list()
1437
-             z_start = z_0[:n_row]
1438
- for t in range(self.num_timesteps):
1439
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1440
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1441
- t = t.to(self.device).long()
1442
- noise = torch.randn_like(z_start)
1443
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1444
- diffusion_row.append(self.decode_first_stage(z_noisy))
1445
-
1446
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1447
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1448
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1449
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1450
- log["diffusion_row"] = diffusion_grid
1451
-
1452
- if sample:
1453
- # get denoise row
1454
- with self.ema_scope("Plotting"):
1455
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1456
- ddim_steps=ddim_steps,eta=ddim_eta)
1457
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1458
- x_samples = self.decode_first_stage(samples)
1459
- log["samples"] = x_samples
1460
- if plot_denoise_rows:
1461
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1462
- log["denoise_row"] = denoise_grid
1463
-
1464
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1465
- self.first_stage_model, IdentityFirstStage):
1466
- # also display when quantizing x0 while sampling
1467
- with self.ema_scope("Plotting Quantized Denoised"):
1468
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1469
- ddim_steps=ddim_steps,eta=ddim_eta,
1470
- quantize_denoised=True)
1471
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1472
- # quantize_denoised=True)
1473
- x_samples = self.decode_first_stage(samples.to(self.device))
1474
- log["samples_x0_quantized"] = x_samples
1475
-
1476
- if inpaint:
1477
- # make a simple center square
1478
-             b, h, w = z_0.shape[0], z_0.shape[2], z_0.shape[3]
1479
- mask = torch.ones(N, h, w).to(self.device)
1480
- # zeros will be filled in
1481
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1482
- mask = mask[:, None, ...]
1483
- with self.ema_scope("Plotting Inpaint"):
1484
-
1485
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1486
-                                              ddim_steps=ddim_steps, x0=z_0[:N], mask=mask)
1487
- x_samples = self.decode_first_stage(samples.to(self.device))
1488
- log["samples_inpainting"] = x_samples
1489
- log["mask"] = mask
1490
-
1491
- # outpaint
1492
- with self.ema_scope("Plotting Outpaint"):
1493
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1494
-                                              ddim_steps=ddim_steps, x0=z_0[:N], mask=mask)
1495
- x_samples = self.decode_first_stage(samples.to(self.device))
1496
- log["samples_outpainting"] = x_samples
1497
-
1498
- if plot_progressive_rows:
1499
- with self.ema_scope("Plotting Progressives"):
1500
- img, progressives = self.progressive_denoising(c,
1501
- shape=(self.channels, self.image_size, self.image_size),
1502
- batch_size=N)
1503
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1504
- log["progressive_row"] = prog_row
1505
-
1506
- if return_keys:
1507
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1508
- return log
1509
- else:
1510
- return {key: log[key] for key in return_keys}
1511
- return log
1512
-
1513
- def configure_optimizers(self):
1514
- lr = self.learning_rate
1515
- params = list(self.model.parameters())
1516
- if self.cond_stage_trainable:
1517
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1518
- params = params + list(self.cond_stage_model.parameters())
1519
- if self.learn_logvar:
1520
- print('Diffusion model optimizing logvar')
1521
- params.append(self.logvar)
1522
- opt = torch.optim.AdamW(params, lr=lr)
1523
- if self.use_scheduler:
1524
- assert 'target' in self.scheduler_config
1525
- scheduler = instantiate_from_config(self.scheduler_config)
1526
-
1527
- print("Setting up LambdaLR scheduler...")
1528
- scheduler = [
1529
- {
1530
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1531
- 'interval': 'step',
1532
- 'frequency': 1
1533
- }]
1534
- return [opt], scheduler
1535
- return opt
1536
-
1537
- @torch.no_grad()
1538
- def to_rgb(self, x):
1539
- x = x.float()
1540
- if not hasattr(self, "colorize"):
1541
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1542
- x = nn.functional.conv2d(x, weight=self.colorize)
1543
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1544
- return x
1545
-
1546
-
1547
-
1548
- class DiffusionWrapper(pl.LightningModule):
1549
- def __init__(self, diff_model_config, conditioning_key):
1550
- super().__init__()
1551
- self.diffusion_model = instantiate_from_config(diff_model_config)
1552
- self.conditioning_key = conditioning_key
1553
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'hybrid_three_for_mask', 'hybrid_separate_mask_block', 'adm']
1554
-
1555
- def forward(self, x_0, t, x_1, c_concat: list = None, c_crossattn: list = None):
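-         # Only the 'hybrid*' branches are wired for this model: they route the image latent x_0 and the mask
-         # latent x_1 through the UNet and return two outputs. The other branches are legacy single-output paths
-         # kept from the original code and do not match the (out_1, out_2) return at the end of this method.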
1556
- if self.conditioning_key is None:
1557
-             out = self.diffusion_model(x_0, t)
1558
- elif self.conditioning_key == 'concat':
1559
-             xc = torch.cat([x_0] + c_concat, dim=1)
1560
- out = self.diffusion_model(xc, t)
1561
- elif self.conditioning_key == 'crossattn':
1562
- cc = torch.cat(c_crossattn, 1)
1563
-             out = self.diffusion_model(x_0, t, context=cc)
1564
- elif self.conditioning_key == 'hybrid':
1565
- xc_0 = torch.cat([x_0] + c_concat, dim=1)
1566
- xc_1 = torch.cat([x_1] + c_concat, dim=1)
1567
- cc = torch.cat(c_crossattn, 1)
1568
- out_1, out_2 = self.diffusion_model(xc_0, xc_1, t, context=cc)
1569
- elif self.conditioning_key == 'hybrid_three_for_mask':
1570
- xc_0 = torch.cat([x_0] + c_concat, dim=1)
1571
- xc_1 = torch.cat([x_0, x_1] + c_concat, dim=1)
1572
- cc = torch.cat(c_crossattn, 1)
1573
- out_1, out_2 = self.diffusion_model(xc_0, xc_1, t, context=cc)
1574
- elif self.conditioning_key == 'hybrid_separate_mask_block':
1575
- xc = torch.cat([x_0] + c_concat, dim=1)
1576
- cc = torch.cat(c_crossattn, 1)
1577
- out_1, out_2 = self.diffusion_model(xc, t, context=cc)
1578
- elif self.conditioning_key == 'adm':
1579
- cc = c_crossattn[0]
1580
-             out = self.diffusion_model(x_0, t, y=cc)
1581
- else:
1582
- raise NotImplementedError()
1583
-
1584
- return out_1, out_2
1585
-
1586
-
1587
- class Layout2ImgDiffusion(LatentDiffusion):
1588
- # TODO: move all layout-specific hacks to this class
1589
- def __init__(self, cond_stage_key, *args, **kwargs):
1590
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1591
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1592
-
1593
- def log_images(self, batch, N=8, *args, **kwargs):
1594
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1595
-
1596
- key = 'train' if self.training else 'validation'
1597
- dset = self.trainer.datamodule.datasets[key]
1598
- mapper = dset.conditional_builders[self.cond_stage_key]
1599
-
1600
- bbox_imgs = []
1601
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1602
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
1603
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1604
- bbox_imgs.append(bboximg)
1605
-
1606
- cond_img = torch.stack(bbox_imgs, dim=0)
1607
- logs['bbox_image'] = cond_img
1608
- return logs
stable_diffusion/ldm/models/diffusion/ddpm_pam_test.py DELETED
@@ -1,1522 +0,0 @@
1
- """
2
- wild mixture of
3
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
- https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
- https://github.com/CompVis/taming-transformers
6
- -- merci
7
- """
8
-
9
- # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
10
- # See more details in LICENSE.
11
-
12
- import torch
13
- import torch.nn as nn
14
- import numpy as np
15
- import pytorch_lightning as pl
16
- from torch.optim.lr_scheduler import LambdaLR
17
- from einops import rearrange, repeat
18
- from contextlib import contextmanager
19
- from functools import partial
20
- from tqdm import tqdm
21
- from torchvision.utils import make_grid
22
- from pytorch_lightning.utilities.distributed import rank_zero_only
23
-
24
- from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
25
- from ldm.modules.ema import LitEma
26
- from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
27
- from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
28
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
29
- from ldm.models.diffusion.ddim import DDIMSampler
30
-
31
-
32
- __conditioning_keys__ = {'concat': 'c_concat',
33
- 'crossattn': 'c_crossattn',
34
- 'adm': 'y'}
35
-
36
-
37
- def disabled_train(self, mode=True):
38
- """Overwrite model.train with this function to make sure train/eval mode
39
- does not change anymore."""
40
- return self
41
-
42
-
43
- def uniform_on_device(r1, r2, shape, device):
44
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
45
-
46
-
47
- class DDPM(pl.LightningModule):
48
- # classic DDPM with Gaussian diffusion, in image space
49
- def __init__(self,
50
- unet_config,
51
- timesteps=1000,
52
- beta_schedule="linear",
53
- loss_type="l2",
54
- ckpt_path=None,
55
- ignore_keys=[],
56
- load_only_unet=False,
57
- monitor="val/loss",
58
- use_ema=True,
59
- first_stage_key="image",
60
- image_size=256,
61
- channels=3,
62
- log_every_t=100,
63
- clip_denoised=True,
64
- linear_start=1e-4,
65
- linear_end=2e-2,
66
- cosine_s=8e-3,
67
- given_betas=None,
68
- original_elbo_weight=0.,
69
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
70
- l_simple_weight=1.,
71
- conditioning_key=None,
72
- parameterization="eps", # all assuming fixed variance schedules
73
- scheduler_config=None,
74
- use_positional_encodings=False,
75
- learn_logvar=False,
76
- logvar_init=0.,
77
- load_ema=True,
78
- ):
79
- super().__init__()
80
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
81
- self.parameterization = parameterization
82
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
83
- self.cond_stage_model = None
84
- self.clip_denoised = clip_denoised
85
- self.log_every_t = log_every_t
86
- self.first_stage_key = first_stage_key
87
- self.image_size = image_size # try conv?
88
- self.channels = channels
89
- self.use_positional_encodings = use_positional_encodings
90
- self.model = DiffusionWrapper(unet_config, conditioning_key)
91
- count_params(self.model, verbose=True)
92
- self.use_ema = use_ema
93
-
94
- self.use_scheduler = scheduler_config is not None
95
- if self.use_scheduler:
96
- self.scheduler_config = scheduler_config
97
-
98
- self.v_posterior = v_posterior
99
- self.original_elbo_weight = original_elbo_weight
100
- self.l_simple_weight = l_simple_weight
101
-
102
- if monitor is not None:
103
- self.monitor = monitor
104
-
105
- if self.use_ema and load_ema:
106
- self.model_ema = LitEma(self.model)
107
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
108
-
109
- if ckpt_path is not None:
110
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
111
-
112
-         # If initializing from an EMA-only checkpoint, create the EMA model after loading.
113
- if self.use_ema and not load_ema:
114
- self.model_ema = LitEma(self.model)
115
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
116
-
117
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
118
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
119
-
120
- self.loss_type = loss_type
121
-
122
- self.learn_logvar = learn_logvar
123
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
124
- if self.learn_logvar:
125
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
126
-
127
-
128
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
129
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
130
- if exists(given_betas):
131
- betas = given_betas
132
- else:
133
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
134
- cosine_s=cosine_s)
135
- alphas = 1. - betas
136
- alphas_cumprod = np.cumprod(alphas, axis=0)
137
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
138
-
139
- timesteps, = betas.shape
140
- self.num_timesteps = int(timesteps)
141
- self.linear_start = linear_start
142
- self.linear_end = linear_end
143
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
144
-
145
- to_torch = partial(torch.tensor, dtype=torch.float32)
146
-
147
- self.register_buffer('betas', to_torch(betas))
148
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
149
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
150
-
151
- # calculations for diffusion q(x_t | x_{t-1}) and others
152
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
153
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
154
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
155
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
156
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
157
-
158
- # calculations for posterior q(x_{t-1} | x_t, x_0)
159
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
160
- 1. - alphas_cumprod) + self.v_posterior * betas
161
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
162
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
163
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
164
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
165
- self.register_buffer('posterior_mean_coef1', to_torch(
166
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
167
- self.register_buffer('posterior_mean_coef2', to_torch(
168
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
169
-
170
- if self.parameterization == "eps":
171
- lvlb_weights = self.betas ** 2 / (
172
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
173
- elif self.parameterization == "x0":
174
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
175
- else:
176
- raise NotImplementedError("mu not supported")
177
- # TODO how to choose this term
178
- lvlb_weights[0] = lvlb_weights[1]
179
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
180
- assert not torch.isnan(self.lvlb_weights).all()
181
-
182
- @contextmanager
183
- def ema_scope(self, context=None):
184
- if self.use_ema:
185
- self.model_ema.store(self.model.parameters())
186
- self.model_ema.copy_to(self.model)
187
- if context is not None:
188
- print(f"{context}: Switched to EMA weights")
189
- try:
190
- yield None
191
- finally:
192
- if self.use_ema:
193
- self.model_ema.restore(self.model.parameters())
194
- if context is not None:
195
- print(f"{context}: Restored training weights")
196
-
197
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
198
- sd = torch.load(path, map_location="cpu")
199
- if "state_dict" in list(sd.keys()):
200
- sd = sd["state_dict"]
201
- keys = list(sd.keys())
202
-
203
- # Our model adds additional channels to the first layer to condition on an input image.
204
- # For the first layer, copy existing channel weights and initialize new channel weights to zero.
205
- input_keys = [
206
- "model.diffusion_model.input_blocks.0.0.weight",
207
- "model_ema.diffusion_modelinput_blocks00weight",
208
- ]
209
-
210
- branch_1_keys = [
211
- "model.diffusion_model.input_blocks_branch_1",
212
- "model.diffusion_model.output_blocks_branch_1",
213
- "model.diffusion_model.out_branch_1",
214
- "model_ema.diffusion_modelinput_blocks_branch_100weight",
215
- "model_ema.diffusion_modelout_branch_10weight",
216
- "model_ema.diffusion_modelout_branch_12weight",
217
-
218
- ]
219
- ignore_keys += branch_1_keys
220
- self_sd = self.state_dict()
221
-
222
-
223
- for input_key in input_keys:
224
- if input_key not in sd or input_key not in self_sd:
225
- continue
226
-
227
- input_weight = self_sd[input_key]
228
-
229
- if input_weight.size() != sd[input_key].size():
230
- print(f"Manual init: {input_key}")
231
- input_weight.zero_()
232
- input_weight[:, :4, :, :].copy_(sd[input_key])
233
- ignore_keys.append(input_key)
234
-
235
-
236
- for branch_1_key in branch_1_keys:
237
- start_with_branch_1_keys = [k for k in self_sd if k.startswith(branch_1_key)]
238
- main_keys = [k.replace("_branch_1", "") for k in start_with_branch_1_keys]
239
-
240
- for start_with_branch_1_key, main_key in zip(start_with_branch_1_keys, main_keys):
241
- if start_with_branch_1_key not in self_sd or main_key not in sd:
242
- continue
243
-
244
- branch_1_weight = self_sd[start_with_branch_1_key]
245
- if branch_1_weight.size() != sd[main_key].size():
246
- print(f"Manual init: {start_with_branch_1_key}")
247
- branch_1_weight.zero_()
248
- branch_1_weight[:, :4, :, :].copy_(sd[main_key])
249
- ignore_keys.append(start_with_branch_1_key)
250
- else:
251
- branch_1_weight.zero_()
252
- branch_1_weight.copy_(sd[main_key])
253
- ignore_keys.append(start_with_branch_1_key)
254
-
255
- for k in keys:
256
- for ik in ignore_keys:
257
- if k.startswith(ik):
258
- print("Deleting key {} from state_dict.".format(k))
259
- del sd[k]
260
-
261
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
262
- sd, strict=False)
263
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
264
- if len(missing) > 0:
265
- print(f"Missing Keys: {missing}")
266
- if len(unexpected) > 0:
267
- print(f"Unexpected Keys: {unexpected}")
268
-
269
-
270
- def q_mean_variance(self, x_start, t):
271
- """
272
- Get the distribution q(x_t | x_0).
273
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
274
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
275
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
276
- """
277
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
278
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
279
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
280
- return mean, variance, log_variance
281
-
282
- def predict_start_from_noise(self, x_t, t, noise):
283
- return (
284
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
285
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
286
- )
287
-
288
- def q_posterior(self, x_start, x_t, t):
289
- posterior_mean = (
290
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
291
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
292
- )
293
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
294
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
295
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
296
-
297
- def p_mean_variance(self, x, t, clip_denoised: bool):
298
- model_out = self.model(x, t)
299
- if self.parameterization == "eps":
300
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
301
- elif self.parameterization == "x0":
302
- x_recon = model_out
303
- if clip_denoised:
304
- x_recon.clamp_(-1., 1.)
305
-
306
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
307
- return model_mean, posterior_variance, posterior_log_variance
308
-
309
- @torch.no_grad()
310
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
311
- b, *_, device = *x.shape, x.device
312
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
313
- noise = noise_like(x.shape, device, repeat_noise)
314
- # no noise when t == 0
315
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
316
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
317
-
318
- @torch.no_grad()
319
- def p_sample_loop(self, shape, return_intermediates=False):
320
- device = self.betas.device
321
- b = shape[0]
322
- img = torch.randn(shape, device=device)
323
- intermediates = [img]
324
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
325
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
326
- clip_denoised=self.clip_denoised)
327
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
328
- intermediates.append(img)
329
- if return_intermediates:
330
- return img, intermediates
331
- return img
332
-
333
- @torch.no_grad()
334
- def sample(self, batch_size=16, return_intermediates=False):
335
- image_size = self.image_size
336
- channels = self.channels
337
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
338
- return_intermediates=return_intermediates)
339
-
340
- def q_sample(self, x_start, t, noise=None):
341
- noise = default(noise, lambda: torch.randn_like(x_start))
342
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
343
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
344
-
345
- def get_loss(self, pred, target, mean=True):
346
- if self.loss_type == 'l1':
347
- loss = (target - pred).abs()
348
- if mean:
349
- loss = loss.mean()
350
- elif self.loss_type == 'l2':
351
- if mean:
352
- loss = torch.nn.functional.mse_loss(target, pred)
353
- else:
354
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
355
- else:
356
-             raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
357
-
358
- return loss
359
-
360
- def p_losses(self, x_start, t, noise=None):
361
- noise = default(noise, lambda: torch.randn_like(x_start))
362
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
363
- model_out = self.model(x_noisy, t)
364
-
365
- loss_dict = {}
366
- if self.parameterization == "eps":
367
- target = noise
368
- elif self.parameterization == "x0":
369
- target = x_start
370
- else:
371
-             raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
372
-
373
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
374
-
375
- log_prefix = 'train' if self.training else 'val'
376
-
377
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
378
- loss_simple = loss.mean() * self.l_simple_weight
379
-
380
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
381
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
382
-
383
- loss = loss_simple + self.original_elbo_weight * loss_vlb
384
-
385
- loss_dict.update({f'{log_prefix}/loss': loss})
386
-
387
- return loss, loss_dict
388
-
389
- def forward(self, x, *args, **kwargs):
390
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
391
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
392
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
393
- return self.p_losses(x, t, *args, **kwargs)
394
-
395
- def get_input(self, batch, k):
396
- return batch[k]
397
-
398
- def shared_step(self, batch):
399
- x = self.get_input(batch, self.first_stage_key)
400
- loss, loss_dict = self(x)
401
- return loss, loss_dict
402
-
403
- def training_step(self, batch, batch_idx):
404
- loss, loss_dict = self.shared_step(batch)
405
-
406
- self.log_dict(loss_dict, prog_bar=True,
407
- logger=True, on_step=True, on_epoch=True)
408
-
409
- self.log("global_step", self.global_step,
410
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
411
-
412
- if self.use_scheduler:
413
- lr = self.optimizers().param_groups[0]['lr']
414
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
415
-
416
- return loss
417
-
418
- @torch.no_grad()
419
- def validation_step(self, batch, batch_idx):
420
- _, loss_dict_no_ema = self.shared_step(batch)
421
- with self.ema_scope():
422
- _, loss_dict_ema = self.shared_step(batch)
423
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
424
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
425
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
426
-
427
- def on_train_batch_end(self, *args, **kwargs):
428
- if self.use_ema:
429
- self.model_ema(self.model)
430
-
431
- def _get_rows_from_list(self, samples):
432
- n_imgs_per_row = len(samples)
433
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
434
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
435
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
436
- return denoise_grid
437
-
438
- @torch.no_grad()
439
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
440
- log = dict()
441
- x = self.get_input(batch, self.first_stage_key)
442
- N = min(x.shape[0], N)
443
- n_row = min(x.shape[0], n_row)
444
- x = x.to(self.device)[:N]
445
- log["inputs"] = x
446
-
447
- # get diffusion row
448
- diffusion_row = list()
449
- x_start = x[:n_row]
450
-
451
- for t in range(self.num_timesteps):
452
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
453
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
454
- t = t.to(self.device).long()
455
- noise = torch.randn_like(x_start)
456
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
457
- diffusion_row.append(x_noisy)
458
-
459
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
460
-
461
- if sample:
462
- # get denoise row
463
- with self.ema_scope("Plotting"):
464
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
465
-
466
- log["samples"] = samples
467
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
468
-
469
- if return_keys:
470
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
471
- return log
472
- else:
473
- return {key: log[key] for key in return_keys}
474
- return log
475
-
476
- def configure_optimizers(self):
477
- lr = self.learning_rate
478
- params = list(self.model.parameters())
479
- if self.learn_logvar:
480
- params = params + [self.logvar]
481
- opt = torch.optim.AdamW(params, lr=lr)
482
- return opt
483
-
484
-
485
- class LatentDiffusion(DDPM):
486
- """main class"""
487
- def __init__(self,
488
- first_stage_config,
489
- cond_stage_config,
490
- num_timesteps_cond=None,
491
- cond_stage_key="image",
492
- cond_stage_trainable=False,
493
- concat_mode=True,
494
- cond_stage_forward=None,
495
- conditioning_key=None,
496
- scale_factor=1.0,
497
- scale_by_std=False,
498
- load_ema=True,
499
- *args, **kwargs):
500
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
501
- self.scale_by_std = scale_by_std
502
- assert self.num_timesteps_cond <= kwargs['timesteps']
503
- # for backwards compatibility after implementation of DiffusionWrapper
504
- if conditioning_key is None:
505
- conditioning_key = 'concat' if concat_mode else 'crossattn'
506
- if cond_stage_config == '__is_unconditional__':
507
- conditioning_key = None
508
- ckpt_path = kwargs.pop("ckpt_path", None)
509
- ignore_keys = kwargs.pop("ignore_keys", [])
510
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
511
- self.concat_mode = concat_mode
512
- self.cond_stage_trainable = cond_stage_trainable
513
- self.cond_stage_key = cond_stage_key
514
- try:
515
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
516
- except:
517
- self.num_downs = 0
518
- if not scale_by_std:
519
- self.scale_factor = scale_factor
520
- else:
521
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
522
- self.instantiate_first_stage(first_stage_config)
523
- self.instantiate_cond_stage(cond_stage_config)
524
- self.cond_stage_forward = cond_stage_forward
525
- self.clip_denoised = False
526
- self.bbox_tokenizer = None
527
-
528
- self.restarted_from_ckpt = False
529
- if ckpt_path is not None:
530
- self.init_from_ckpt(ckpt_path, ignore_keys)
531
- self.restarted_from_ckpt = True
532
-
533
- if self.use_ema and not load_ema:
534
- self.model_ema = LitEma(self.model)
535
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
536
-
537
- def make_cond_schedule(self, ):
538
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
539
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
540
- self.cond_ids[:self.num_timesteps_cond] = ids
541
-
542
- @rank_zero_only
543
- @torch.no_grad()
544
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
545
- # only for very first batch
546
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
547
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
548
- # set rescale weight to 1./std of encodings
549
- print("### USING STD-RESCALING ###")
550
- x = super().get_input(batch, self.first_stage_key)
551
- x = x.to(self.device)
552
- encoder_posterior = self.encode_first_stage(x)
553
- z = self.get_first_stage_encoding(encoder_posterior).detach()
554
- del self.scale_factor
555
- self.register_buffer('scale_factor', 1. / z.flatten().std())
556
- print(f"setting self.scale_factor to {self.scale_factor}")
557
- print("### USING STD-RESCALING ###")
558
-
559
- def register_schedule(self,
560
- given_betas=None, beta_schedule="linear", timesteps=1000,
561
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
562
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
563
-
564
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
565
- if self.shorten_cond_schedule:
566
- self.make_cond_schedule()
567
-
568
- def instantiate_first_stage(self, config):
569
- model = instantiate_from_config(config)
570
- self.first_stage_model = model.eval()
571
- self.first_stage_model.train = disabled_train
572
- for param in self.first_stage_model.parameters():
573
- param.requires_grad = False
574
-
575
- def instantiate_cond_stage(self, config):
576
- if not self.cond_stage_trainable:
577
- if config == "__is_first_stage__":
578
- print("Using first stage also as cond stage.")
579
- self.cond_stage_model = self.first_stage_model
580
- elif config == "__is_unconditional__":
581
- print(f"Training {self.__class__.__name__} as an unconditional model.")
582
- self.cond_stage_model = None
583
- # self.be_unconditional = True
584
- else:
585
- model = instantiate_from_config(config)
586
- self.cond_stage_model = model.eval()
587
- self.cond_stage_model.train = disabled_train
588
- for param in self.cond_stage_model.parameters():
589
- param.requires_grad = False
590
- else:
591
- assert config != '__is_first_stage__'
592
- assert config != '__is_unconditional__'
593
- model = instantiate_from_config(config)
594
- self.cond_stage_model = model
595
-
596
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
597
- denoise_row = []
598
- for zd in tqdm(samples, desc=desc):
599
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
600
- force_not_quantize=force_no_decoder_quantization))
601
- n_imgs_per_row = len(denoise_row)
602
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
603
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
604
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
605
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
606
- return denoise_grid
607
-
608
- def get_first_stage_encoding(self, encoder_posterior):
609
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
610
- z = encoder_posterior.sample()
611
- elif isinstance(encoder_posterior, torch.Tensor):
612
- z = encoder_posterior
613
- else:
614
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
615
- return self.scale_factor * z
616
-
617
- def get_learned_conditioning(self, c):
618
- if self.cond_stage_forward is None:
619
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
620
- c = self.cond_stage_model.encode(c)
621
- if isinstance(c, DiagonalGaussianDistribution):
622
- c = c.mode()
623
- else:
624
- c = self.cond_stage_model(c)
625
- else:
626
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
627
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
628
- return c
629
-
630
- def meshgrid(self, h, w):
631
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
632
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
633
-
634
- arr = torch.cat([y, x], dim=-1)
635
- return arr
636
-
637
- def delta_border(self, h, w):
638
- """
639
- :param h: height
640
- :param w: width
641
- :return: normalized distance to image border,
642
- with min distance = 0 at border and max dist = 0.5 at image center
643
- """
644
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
645
- arr = self.meshgrid(h, w) / lower_right_corner
646
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
647
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
648
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
649
- return edge_dist
650
-
651
- def get_weighting(self, h, w, Ly, Lx, device):
652
- weighting = self.delta_border(h, w)
653
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
654
- self.split_input_params["clip_max_weight"], )
655
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
656
-
657
- if self.split_input_params["tie_braker"]:
658
- L_weighting = self.delta_border(Ly, Lx)
659
- L_weighting = torch.clip(L_weighting,
660
- self.split_input_params["clip_min_tie_weight"],
661
- self.split_input_params["clip_max_tie_weight"])
662
-
663
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
664
- weighting = weighting * L_weighting
665
- return weighting
666
-
667
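The behaviour documented in delta_border is easiest to see on a toy grid; the sketch below re-implements the same computation standalone (illustrative only, not part of the repository) and prints the 5x5 map, which is 0.0 on the border and 0.5 at the image center:

import torch

def delta_border_sketch(h, w):
    # normalized distance to the nearest image border (0 at border, 0.5 at center)
    y = torch.arange(h).view(h, 1, 1).repeat(1, w, 1)
    x = torch.arange(w).view(1, w, 1).repeat(h, 1, 1)
    arr = torch.cat([y, x], dim=-1) / torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    dist_left_up = torch.min(arr, dim=-1, keepdim=True)[0]
    dist_right_down = torch.min(1 - arr, dim=-1, keepdim=True)[0]
    return torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]

print(delta_border_sketch(5, 5))  # corners and edges -> 0.0, center -> 0.5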
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
668
- """
669
- :param x: img of size (bs, c, h, w)
670
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
671
- """
672
- bs, nc, h, w = x.shape
673
-
674
- # number of crops in image
675
- Ly = (h - kernel_size[0]) // stride[0] + 1
676
- Lx = (w - kernel_size[1]) // stride[1] + 1
677
-
678
- if uf == 1 and df == 1:
679
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
680
- unfold = torch.nn.Unfold(**fold_params)
681
-
682
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
683
-
684
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
685
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
686
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
687
-
688
- elif uf > 1 and df == 1:
689
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
690
- unfold = torch.nn.Unfold(**fold_params)
691
-
692
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
693
- dilation=1, padding=0,
694
- stride=(stride[0] * uf, stride[1] * uf))
695
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
696
-
697
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
698
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
699
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
700
-
701
- elif df > 1 and uf == 1:
702
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
703
- unfold = torch.nn.Unfold(**fold_params)
704
-
705
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
706
- dilation=1, padding=0,
707
- stride=(stride[0] // df, stride[1] // df))
708
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
709
-
710
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
711
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
712
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
713
-
714
- else:
715
- raise NotImplementedError
716
-
717
- return fold, unfold, normalization, weighting
718
-
719
- @torch.no_grad()
720
- def get_input(self, batch, keys, return_first_stage_outputs=False, force_c_encode=False,
721
- cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
722
- x_0 = super().get_input(batch, keys[0])
723
- x_1 = super().get_input(batch, keys[1])
724
- if bs is not None:
725
- x_0 = x_0[:bs]
726
- x_1 = x_1[:bs]
727
- x_0 = x_0.to(self.device)
728
- x_1 = x_1.to(self.device)
729
- encoder_posterior = self.encode_first_stage(x_0)
730
- z_0 = self.get_first_stage_encoding(encoder_posterior).detach()
731
- z_1 = self.get_first_stage_encoding(self.encode_first_stage(x_1)).detach()
732
- cond_key = cond_key or self.cond_stage_key
733
- xc = super().get_input(batch, cond_key)
734
- if bs is not None:
735
- xc["c_crossattn"] = xc["c_crossattn"][:bs]
736
- xc["c_concat"] = xc["c_concat"][:bs]
737
- cond = {}
738
-
739
- # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
740
- random = torch.rand(x_0.size(0), device=x_0.device)
741
- prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
742
- input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
743
-
744
- null_prompt = self.get_learned_conditioning([""])
745
- cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
746
- cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
747
-
748
- out = [z_0, z_1, cond]
749
- if return_first_stage_outputs:
750
- x_0_rec = self.decode_first_stage(z_0)
751
- x_1_rec = self.decode_first_stage(z_1)
752
- out.extend([x_0, x_0_rec, x_1, x_1_rec])
753
- if return_original_cond:
754
- out.append(xc)
755
-
756
- return out
757
-
758
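The three-way conditioning dropout above is easier to read outside of the rearranges. A minimal standalone sketch of the same masking, assuming precomputed text embeddings of shape (n, seq, dim) and conditioning latents of shape (n, c, h, w); the function and argument names are illustrative, and a zero tensor stands in for the learned encoding of the empty prompt:

import torch
from einops import rearrange

def cfg_dropout(text_emb, image_latent, uncond=0.05, null_text=None):
    # Three-way classifier-free-guidance dropout: with probability `uncond` drop
    # only the text, with `uncond` drop only the image, with `uncond` drop both.
    n = text_emb.shape[0]
    r = torch.rand(n, device=text_emb.device)
    # r < 2*uncond  -> text is dropped (covers "text only" and "both")
    text_mask = rearrange(r < 2 * uncond, "n -> n 1 1")
    # uncond <= r < 3*uncond -> image is dropped (covers "both" and "image only")
    drop_image = (r >= uncond) & (r < 3 * uncond)
    image_mask = 1.0 - rearrange(drop_image.float(), "n -> n 1 1 1")
    # stand-in for the encoding of the empty prompt used in the method above
    null_text = null_text if null_text is not None else torch.zeros_like(text_emb)
    return torch.where(text_mask, null_text, text_emb), image_mask * image_latent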
- @torch.no_grad()
759
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
760
- if predict_cids:
761
- if z.dim() == 4:
762
- z = torch.argmax(z.exp(), dim=1).long()
763
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
764
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
765
-
766
- z = 1. / self.scale_factor * z
767
-
768
- if hasattr(self, "split_input_params"):
769
- if self.split_input_params["patch_distributed_vq"]:
770
- ks = self.split_input_params["ks"] # eg. (128, 128)
771
- stride = self.split_input_params["stride"] # eg. (64, 64)
772
- uf = self.split_input_params["vqf"]
773
- bs, nc, h, w = z.shape
774
- if ks[0] > h or ks[1] > w:
775
- ks = (min(ks[0], h), min(ks[1], w))
776
- print("reducing Kernel")
777
-
778
- if stride[0] > h or stride[1] > w:
779
- stride = (min(stride[0], h), min(stride[1], w))
780
- print("reducing stride")
781
-
782
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
783
-
784
- z = unfold(z) # (bn, nc * prod(**ks), L)
785
- # 1. Reshape to img shape
786
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
787
-
788
- # 2. apply model loop over last dim
789
- if isinstance(self.first_stage_model, VQModelInterface):
790
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
791
- force_not_quantize=predict_cids or force_not_quantize)
792
- for i in range(z.shape[-1])]
793
- else:
794
-
795
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
796
- for i in range(z.shape[-1])]
797
-
798
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
799
- o = o * weighting
800
- # Reverse 1. reshape to img shape
801
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
802
- # stitch crops together
803
- decoded = fold(o)
804
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
805
- return decoded
806
- else:
807
- if isinstance(self.first_stage_model, VQModelInterface):
808
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
809
- else:
810
- return self.first_stage_model.decode(z)
811
-
812
- else:
813
- if isinstance(self.first_stage_model, VQModelInterface):
814
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
815
- else:
816
- return self.first_stage_model.decode(z)
817
-
818
- # same as above but without decorator
819
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
820
- if predict_cids:
821
- if z.dim() == 4:
822
- z = torch.argmax(z.exp(), dim=1).long()
823
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
824
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
825
-
826
- z = 1. / self.scale_factor * z
827
-
828
- if hasattr(self, "split_input_params"):
829
- if self.split_input_params["patch_distributed_vq"]:
830
- ks = self.split_input_params["ks"] # eg. (128, 128)
831
- stride = self.split_input_params["stride"] # eg. (64, 64)
832
- uf = self.split_input_params["vqf"]
833
- bs, nc, h, w = z.shape
834
- if ks[0] > h or ks[1] > w:
835
- ks = (min(ks[0], h), min(ks[1], w))
836
- print("reducing Kernel")
837
-
838
- if stride[0] > h or stride[1] > w:
839
- stride = (min(stride[0], h), min(stride[1], w))
840
- print("reducing stride")
841
-
842
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
843
-
844
- z = unfold(z) # (bn, nc * prod(**ks), L)
845
- # 1. Reshape to img shape
846
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
847
-
848
- # 2. apply model loop over last dim
849
- if isinstance(self.first_stage_model, VQModelInterface):
850
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
851
- force_not_quantize=predict_cids or force_not_quantize)
852
- for i in range(z.shape[-1])]
853
- else:
854
-
855
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
856
- for i in range(z.shape[-1])]
857
-
858
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
859
- o = o * weighting
860
- # Reverse 1. reshape to img shape
861
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
862
- # stitch crops together
863
- decoded = fold(o)
864
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
865
- return decoded
866
- else:
867
- if isinstance(self.first_stage_model, VQModelInterface):
868
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
869
- else:
870
- return self.first_stage_model.decode(z)
871
-
872
- else:
873
- if isinstance(self.first_stage_model, VQModelInterface):
874
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
875
- else:
876
- return self.first_stage_model.decode(z)
877
-
878
- @torch.no_grad()
879
- def encode_first_stage(self, x):
880
- if hasattr(self, "split_input_params"):
881
- if self.split_input_params["patch_distributed_vq"]:
882
- ks = self.split_input_params["ks"] # eg. (128, 128)
883
- stride = self.split_input_params["stride"] # eg. (64, 64)
884
- df = self.split_input_params["vqf"]
885
- self.split_input_params['original_image_size'] = x.shape[-2:]
886
- bs, nc, h, w = x.shape
887
- if ks[0] > h or ks[1] > w:
888
- ks = (min(ks[0], h), min(ks[1], w))
889
- print("reducing Kernel")
890
-
891
- if stride[0] > h or stride[1] > w:
892
- stride = (min(stride[0], h), min(stride[1], w))
893
- print("reducing stride")
894
-
895
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
896
- z = unfold(x) # (bn, nc * prod(**ks), L)
897
- # Reshape to img shape
898
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
899
-
900
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
901
- for i in range(z.shape[-1])]
902
-
903
- o = torch.stack(output_list, axis=-1)
904
- o = o * weighting
905
-
906
- # Reverse reshape to img shape
907
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
908
- # stitch crops together
909
- decoded = fold(o)
910
- decoded = decoded / normalization
911
- return decoded
912
-
913
- else:
914
- return self.first_stage_model.encode(x)
915
- else:
916
- return self.first_stage_model.encode(x)
917
-
918
- def shared_step(self, batch, **kwargs):
919
- x_0, x_1, c = self.get_input(batch, self.first_stage_key)
920
- loss = self(x_0, x_1, c)
921
- return loss
922
-
923
- def forward(self, x_0, x_1, c, *args, **kwargs):
924
- t = torch.randint(0, self.num_timesteps, (x_0.shape[0],), device=self.device).long()
925
- if self.model.conditioning_key is not None:
926
- assert c is not None
927
- # in pix2pix, cond_stage_trainable and shorten_cond_schedule are false
928
- if self.cond_stage_trainable:
929
- c = self.get_learned_conditioning(c)
930
- if self.shorten_cond_schedule: # TODO: drop this option
931
- tc = self.cond_ids[t].to(self.device)
932
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
933
- return self.p_losses(x_0, x_1, c, t, *args, **kwargs)
934
-
935
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
936
- def rescale_bbox(bbox):
937
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
938
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
939
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
940
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
941
- return x0, y0, w, h
942
-
943
- return [rescale_bbox(b) for b in bboxes]
944
-
945
- def apply_model(self, x_noisy_0, x_noisy_1, t, cond, return_ids=False):
946
- if isinstance(cond, dict):
947
- # hybrid case, cond is expected to be a dict
948
- pass
949
- else:
950
- if not isinstance(cond, list):
951
- cond = [cond]
952
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
953
- cond = {key: cond}
954
-
955
- if hasattr(self, "split_input_params"):
956
- assert len(cond) == 1 # todo can only deal with one conditioning atm
957
- assert not return_ids
958
- ks = self.split_input_params["ks"] # eg. (128, 128)
959
- stride = self.split_input_params["stride"] # eg. (64, 64)
960
-
961
- h, w = x_noisy.shape[-2:]
962
-
963
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
964
-
965
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
966
- # Reshape to img shape
967
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
968
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
969
-
970
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
971
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
972
- c_key = next(iter(cond.keys())) # get key
973
- c = next(iter(cond.values())) # get value
974
- assert (len(c) == 1) # todo extend to list with more than one elem
975
- c = c[0] # get element
976
-
977
- c = unfold(c)
978
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
979
-
980
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
981
-
982
- elif self.cond_stage_key == 'coordinates_bbox':
983
- assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
984
-
985
- # assuming padding of unfold is always 0 and its dilation is always 1
986
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
987
- full_img_h, full_img_w = self.split_input_params['original_image_size']
988
- # as we are operating on latents, we need the factor from the original image size to the
989
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
990
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
991
- rescale_latent = 2 ** (num_downs)
992
-
993
- # get top-left positions of patches in the form expected by the bbox tokenizer, therefore we
994
- # need to rescale the tl patch coordinates to be in between (0,1)
995
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
996
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
997
- for patch_nr in range(z.shape[-1])]
998
-
999
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
1000
- patch_limits = [(x_tl, y_tl,
1001
- rescale_latent * ks[0] / full_img_w,
1002
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
1003
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
1004
-
1005
- # tokenize crop coordinates for the bounding boxes of the respective patches
1006
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
1007
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
1008
- print(patch_limits_tknzd[0].shape)
1009
- # cut tknzd crop position from conditioning
1010
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
1011
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
1012
- print(cut_cond.shape)
1013
-
1014
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
1015
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
1016
- print(adapted_cond.shape)
1017
- adapted_cond = self.get_learned_conditioning(adapted_cond)
1018
- print(adapted_cond.shape)
1019
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
1020
- print(adapted_cond.shape)
1021
-
1022
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
1023
-
1024
- else:
1025
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
1026
-
1027
- # apply model by loop over crops
1028
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1029
- assert not isinstance(output_list[0],
1030
- tuple) # todo: can't deal with multiple model outputs; check this never happens
1031
-
1032
- o = torch.stack(output_list, axis=-1)
1033
- o = o * weighting
1034
- # Reverse reshape to img shape
1035
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1036
- # stitch crops together
1037
- x_recon = fold(o) / normalization
1038
-
1039
- else:
1040
- x_recon_0, x_recon_1 = self.model(x_noisy_0, x_noisy_1, t, **cond)
1041
-
1042
- if isinstance(x_recon_0, tuple) and not return_ids:
1043
- return x_recon_0[0], x_recon_1[0]
1044
- else:
1045
- return x_recon_0, x_recon_1
1046
-
1047
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1048
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1049
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1050
-
1051
- def _prior_bpd(self, x_start):
1052
- """
1053
- Get the prior KL term for the variational lower-bound, measured in
1054
- bits-per-dim.
1055
- This term can't be optimized, as it only depends on the encoder.
1056
- :param x_start: the [N x C x ...] tensor of inputs.
1057
- :return: a batch of [N] KL values (in bits), one per batch element.
1058
- """
1059
- batch_size = x_start.shape[0]
1060
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1061
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1062
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1063
- return mean_flat(kl_prior) / np.log(2.0)
1064
-
1065
- def p_losses(self, x_start_0, x_start_1, cond, t, noise=None):
1066
- noise_0 = default(noise, lambda: torch.randn_like(x_start_0))
1067
- noise_1 = default(noise, lambda: torch.randn_like(x_start_1))
1068
- x_noisy_0 = self.q_sample(x_start=x_start_0, t=t, noise=noise_0)
1069
- x_noisy_1 = self.q_sample(x_start=x_start_1, t=t, noise=noise_1)
1070
- model_output_0, model_output_1 = self.apply_model(x_noisy_0, x_noisy_1, t, cond)
1071
-
1072
- loss_dict = {}
1073
- prefix = 'train' if self.training else 'val'
1074
-
1075
- if self.parameterization == "x0":
1076
- target_0 = x_start_0
1077
- target_1 = x_start_1
1078
- elif self.parameterization == "eps":
1079
- target_0 = noise_0
1080
- target_1 = noise_1
1081
- else:
1082
- raise NotImplementedError()
1083
-
1084
- loss_simple_0 = self.get_loss(model_output_0, target_0, mean=False).mean([1, 2, 3])
1085
- loss_simple_1 = self.get_loss(model_output_1, target_1, mean=False).mean([1, 2, 3])
1086
- loss_simple = (loss_simple_0 + loss_simple_1) / 2
1087
-
1088
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1089
-
1090
- # logvar_t = self.logvar[t].to(self.device)
1091
- # make sure self.logvar is on the same device as the model (self.device)
1092
- self.logvar = self.logvar.to(self.device)
1093
- logvar_t = self.logvar[t]
1094
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
1095
-
1096
- if self.learn_logvar:
1097
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1098
- loss_dict.update({'logvar': self.logvar.data.mean()})
1099
-
1100
- loss = self.l_simple_weight * loss.mean()
1101
-
1102
- loss_vlb_0 = self.get_loss(model_output_0, target_0, mean=False).mean(dim=(1, 2, 3))
1103
- loss_vlb_1 = self.get_loss(model_output_1, target_1, mean=False).mean(dim=(1, 2, 3))
1104
- loss_vlb = (loss_vlb_0 + loss_vlb_1) / 2
1105
-
1106
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1107
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1108
- loss += (self.original_elbo_weight * loss_vlb)
1109
- loss_dict.update({f'{prefix}/loss': loss})
1110
-
1111
- return loss, loss_dict
1112
-
1113
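For reference, the reweighting applied in p_losses above can be read in isolation. The sketch below only mirrors those lines; loss_simple, loss_vlb, logvar_t and lvlb_weight_t are assumed to be per-sample tensors as in the method above, and the names are illustrative:

import torch

def combine_losses(loss_simple, loss_vlb, logvar_t, lvlb_weight_t,
                   l_simple_weight=1.0, original_elbo_weight=0.0):
    # per-sample simple loss reweighted by a (possibly learned) log-variance
    weighted = loss_simple / torch.exp(logvar_t) + logvar_t
    loss = l_simple_weight * weighted.mean()
    # variational-lower-bound term, usually switched off (original_elbo_weight=0)
    loss = loss + original_elbo_weight * (lvlb_weight_t * loss_vlb).mean()
    return loss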
- def p_mean_variance(self, x_0, x_1, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1114
- return_x0=False, score_corrector=None, corrector_kwargs=None):
1115
- t_in = t
1116
- model_out_0, model_out_1 = self.apply_model(x_0, x_1, t_in, c, return_ids=return_codebook_ids)
1117
-
1118
- if score_corrector is not None:
1119
- assert self.parameterization == "eps"
1120
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1121
-
1122
- if return_codebook_ids:
1123
- model_out, logits = model_out
1124
-
1125
- if self.parameterization == "eps":
1126
- x_recon_0 = self.predict_start_from_noise(x_0, t=t, noise=model_out_0)
1127
- x_recon_1 = self.predict_start_from_noise(x_1, t=t, noise=model_out_1)
1128
- elif self.parameterization == "x0":
1129
- x_recon = model_out
1130
- else:
1131
- raise NotImplementedError()
1132
- if clip_denoised:
1133
- x_recon.clamp_(-1., 1.)
1134
- if quantize_denoised:
1135
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1136
-
1137
- model_mean_0, posterior_variance_0, posterior_log_variance_0 = self.q_posterior(x_start=x_recon_0, x_t=x_0, t=t)
1138
- model_mean_1, posterior_variance_1, posterior_log_variance_1 = self.q_posterior(x_start=x_recon_1, x_t=x_1, t=t)
1139
- if return_codebook_ids:
1140
- return model_mean, posterior_variance, posterior_log_variance, logits
1141
- elif return_x0:
1142
- return model_mean, posterior_variance, posterior_log_variance, x_recon
1143
- else:
1144
- return model_mean_0, posterior_variance_0, posterior_log_variance_0, model_mean_1, posterior_variance_1, posterior_log_variance_1
1145
-
1146
- @torch.no_grad()
1147
- def p_sample(self, x_0, x_1, c, t, clip_denoised=False, repeat_noise=False,
1148
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1149
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1150
- b, *_, device = *x_0.shape, x_0.device
1151
- outputs = self.p_mean_variance(x_0=x_0, x_1=x_1, c=c, t=t, clip_denoised=clip_denoised,
1152
- return_codebook_ids=return_codebook_ids,
1153
- quantize_denoised=quantize_denoised,
1154
- return_x0=return_x0,
1155
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1156
-
1157
- if return_codebook_ids:
1158
- raise DeprecationWarning("Support dropped.")
1159
- model_mean, _, model_log_variance, logits = outputs
1160
- elif return_x0:
1161
- model_mean, _, model_log_variance, x0 = outputs
1162
- else:
1163
- model_mean_0, _, model_log_variance_0, model_mean_1, _, model_log_variance_1 = outputs
1164
-
1165
- noise = noise_like(x_0.shape, device, repeat_noise) * temperature
1166
- if noise_dropout > 0.:
1167
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1168
- # no noise when t == 0
1169
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_0.shape) - 1)))
1170
-
1171
- if return_codebook_ids:
1172
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1173
- if return_x0:
1174
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1175
- else:
1176
- return model_mean_0 + nonzero_mask * (0.5 * model_log_variance_0).exp() * noise, \
1177
- model_mean_1 + nonzero_mask * (0.5 * model_log_variance_1).exp() * noise
1178
-
1179
- @torch.no_grad()
1180
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1181
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1182
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1183
- log_every_t=None):
1184
- if not log_every_t:
1185
- log_every_t = self.log_every_t
1186
- timesteps = self.num_timesteps
1187
- if batch_size is not None:
1188
- b = batch_size if batch_size is not None else shape[0]
1189
- shape = [batch_size] + list(shape)
1190
- else:
1191
- b = batch_size = shape[0]
1192
- if x_T is None:
1193
- img = torch.randn(shape, device=self.device)
1194
- else:
1195
- img = x_T
1196
- intermediates = []
1197
- if cond is not None:
1198
- if isinstance(cond, dict):
1199
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1200
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1201
- else:
1202
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1203
-
1204
- if start_T is not None:
1205
- timesteps = min(timesteps, start_T)
1206
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1207
- total=timesteps) if verbose else reversed(
1208
- range(0, timesteps))
1209
- if type(temperature) == float:
1210
- temperature = [temperature] * timesteps
1211
-
1212
- for i in iterator:
1213
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1214
- if self.shorten_cond_schedule:
1215
- assert self.model.conditioning_key != 'hybrid'
1216
- tc = self.cond_ids[ts].to(cond.device)
1217
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1218
-
1219
- img, x0_partial = self.p_sample(img, cond, ts,
1220
- clip_denoised=self.clip_denoised,
1221
- quantize_denoised=quantize_denoised, return_x0=True,
1222
- temperature=temperature[i], noise_dropout=noise_dropout,
1223
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1224
- if mask is not None:
1225
- assert x0 is not None
1226
- img_orig = self.q_sample(x0, ts)
1227
- img = img_orig * mask + (1. - mask) * img
1228
-
1229
- if i % log_every_t == 0 or i == timesteps - 1:
1230
- intermediates.append(x0_partial)
1231
- if callback: callback(i)
1232
- if img_callback: img_callback(img, i)
1233
- return img, intermediates
1234
-
1235
- @torch.no_grad()
1236
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1237
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1238
- mask=None, x0=None, img_callback=None, start_T=None,
1239
- log_every_t=None):
1240
-
1241
- if not log_every_t:
1242
- log_every_t = self.log_every_t
1243
- device = self.betas.device
1244
- b = shape[0]
1245
-
1246
- if x_T is None:
1247
- img_0 = torch.randn(shape, device=device)
1248
- img_1 = torch.randn(shape, device=device)
1249
- else:
1250
- img = x_T
1251
-
1252
- intermediates = [img_0]
1253
- if timesteps is None:
1254
- timesteps = self.num_timesteps
1255
-
1256
- if start_T is not None:
1257
- timesteps = min(timesteps, start_T)
1258
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1259
- range(0, timesteps))
1260
-
1261
- if mask is not None:
1262
- assert x0 is not None
1263
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1264
-
1265
- for i in iterator:
1266
- ts = torch.full((b,), i, device=device, dtype=torch.long)
1267
- if self.shorten_cond_schedule:
1268
- assert self.model.conditioning_key != 'hybrid'
1269
- tc = self.cond_ids[ts].to(cond.device)
1270
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1271
-
1272
- img_0, img_1 = self.p_sample(img_0, img_1, cond, ts,
1273
- clip_denoised=self.clip_denoised,
1274
- quantize_denoised=quantize_denoised)
1275
-
1276
- if mask is not None:
1277
- img_orig = self.q_sample(x0, ts)
1278
- img = img_orig * mask + (1. - mask) * img
1279
-
1280
- if i % log_every_t == 0 or i == timesteps - 1:
1281
- intermediates.append(img_0)
1282
- if callback: callback(i)
1283
- if img_callback: img_callback(img, i)
1284
-
1285
- if return_intermediates:
1286
- return img_0, intermediates
1287
- return img_0
1288
-
1289
- @torch.no_grad()
1290
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1291
- verbose=True, timesteps=None, quantize_denoised=False,
1292
- mask=None, x0=None, shape=None,**kwargs):
1293
- if shape is None:
1294
- shape = (batch_size, self.channels, self.image_size, self.image_size)
1295
- if cond is not None:
1296
- if isinstance(cond, dict):
1297
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1298
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1299
- else:
1300
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1301
- return self.p_sample_loop(cond,
1302
- shape,
1303
- return_intermediates=return_intermediates, x_T=x_T,
1304
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1305
- mask=mask, x0=x0)
1306
-
1307
- @torch.no_grad()
1308
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1309
-
1310
- if ddim:
1311
- ddim_sampler = DDIMSampler(self)
1312
- shape = (self.channels, self.image_size, self.image_size)
1313
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1314
- shape,cond,verbose=False,**kwargs)
1315
-
1316
- else:
1317
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1318
- return_intermediates=True,**kwargs)
1319
-
1320
- return samples, intermediates
1321
-
1322
-
1323
- @torch.no_grad()
1324
- def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1325
- quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
1326
- plot_diffusion_rows=False, **kwargs):
1327
-
1328
- use_ddim = False
1329
-
1330
- log = dict()
1331
- # z_0, z_1, c, x_0, x_0_rec, x_1, x_1_rec, xc
1332
- # z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1333
- z_0, z_1, c, x_0, x_0_rec, x_1, x_1_rec, xc = self.get_input(batch, self.first_stage_key,
1334
- return_first_stage_outputs=True,
1335
- force_c_encode=True,
1336
- return_original_cond=True,
1337
- bs=N, uncond=0)
1338
- N = min(x_0.shape[0], N)
1339
- n_row = min(x_0.shape[0], n_row)
1340
- log["inputs"] = x_0
1341
- log["reals"] = xc["c_concat"]
1342
- log["reconstruction"] = x_0_rec
1343
- if self.model.conditioning_key is not None:
1344
- if hasattr(self.cond_stage_model, "decode"):
1345
- xc = self.cond_stage_model.decode(c)
1346
- log["conditioning"] = xc
1347
- elif self.cond_stage_key in ["caption"]:
1348
- xc = log_txt_as_img((x_0.shape[2], x_0.shape[3]), batch["caption"])
1349
- log["conditioning"] = xc
1350
- elif self.cond_stage_key == 'class_label':
1351
- xc = log_txt_as_img((x_0.shape[2], x_0.shape[3]), batch["human_label"])
1352
- log['conditioning'] = xc
1353
- elif isimage(xc):
1354
- log["conditioning"] = xc
1355
- if ismap(xc):
1356
- log["original_conditioning"] = self.to_rgb(xc)
1357
-
1358
- if plot_diffusion_rows:
1359
- # get diffusion row
1360
- diffusion_row = list()
1361
- z_start = z[:n_row]
1362
- for t in range(self.num_timesteps):
1363
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1364
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1365
- t = t.to(self.device).long()
1366
- noise = torch.randn_like(z_start)
1367
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1368
- diffusion_row.append(self.decode_first_stage(z_noisy))
1369
-
1370
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1371
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1372
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1373
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1374
- log["diffusion_row"] = diffusion_grid
1375
-
1376
- if sample:
1377
- # get denoise row
1378
- with self.ema_scope("Plotting"):
1379
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1380
- ddim_steps=ddim_steps,eta=ddim_eta)
1381
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1382
- x_samples = self.decode_first_stage(samples)
1383
- log["samples"] = x_samples
1384
- if plot_denoise_rows:
1385
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1386
- log["denoise_row"] = denoise_grid
1387
-
1388
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1389
- self.first_stage_model, IdentityFirstStage):
1390
- # also display when quantizing x0 while sampling
1391
- with self.ema_scope("Plotting Quantized Denoised"):
1392
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1393
- ddim_steps=ddim_steps,eta=ddim_eta,
1394
- quantize_denoised=True)
1395
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1396
- # quantize_denoised=True)
1397
- x_samples = self.decode_first_stage(samples.to(self.device))
1398
- log["samples_x0_quantized"] = x_samples
1399
-
1400
- if inpaint:
1401
- # make a simple center square
1402
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
1403
- mask = torch.ones(N, h, w).to(self.device)
1404
- # zeros will be filled in
1405
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1406
- mask = mask[:, None, ...]
1407
- with self.ema_scope("Plotting Inpaint"):
1408
-
1409
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1410
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1411
- x_samples = self.decode_first_stage(samples.to(self.device))
1412
- log["samples_inpainting"] = x_samples
1413
- log["mask"] = mask
1414
-
1415
- # outpaint
1416
- with self.ema_scope("Plotting Outpaint"):
1417
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1418
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1419
- x_samples = self.decode_first_stage(samples.to(self.device))
1420
- log["samples_outpainting"] = x_samples
1421
-
1422
- if plot_progressive_rows:
1423
- with self.ema_scope("Plotting Progressives"):
1424
- img, progressives = self.progressive_denoising(c,
1425
- shape=(self.channels, self.image_size, self.image_size),
1426
- batch_size=N)
1427
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1428
- log["progressive_row"] = prog_row
1429
-
1430
- if return_keys:
1431
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1432
- return log
1433
- else:
1434
- return {key: log[key] for key in return_keys}
1435
- return log
1436
-
1437
- def configure_optimizers(self):
1438
- lr = self.learning_rate
1439
- params = list(self.model.parameters())
1440
- if self.cond_stage_trainable:
1441
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1442
- params = params + list(self.cond_stage_model.parameters())
1443
- if self.learn_logvar:
1444
- print('Diffusion model optimizing logvar')
1445
- params.append(self.logvar)
1446
- opt = torch.optim.AdamW(params, lr=lr)
1447
- if self.use_scheduler:
1448
- assert 'target' in self.scheduler_config
1449
- scheduler = instantiate_from_config(self.scheduler_config)
1450
-
1451
- print("Setting up LambdaLR scheduler...")
1452
- scheduler = [
1453
- {
1454
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1455
- 'interval': 'step',
1456
- 'frequency': 1
1457
- }]
1458
- return [opt], scheduler
1459
- return opt
1460
-
1461
- @torch.no_grad()
1462
- def to_rgb(self, x):
1463
- x = x.float()
1464
- if not hasattr(self, "colorize"):
1465
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1466
- x = nn.functional.conv2d(x, weight=self.colorize)
1467
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1468
- return x
1469
-
1470
-
1471
- class DiffusionWrapper(pl.LightningModule):
1472
- def __init__(self, diff_model_config, conditioning_key):
1473
- super().__init__()
1474
- self.diffusion_model = instantiate_from_config(diff_model_config)
1475
- self.conditioning_key = conditioning_key
1476
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1477
-
1478
- def forward(self, x_0, x_1, t, c_concat: list = None, c_crossattn: list = None):
1479
- if self.conditioning_key is None:
1480
- out = self.diffusion_model(x, t)
1481
- elif self.conditioning_key == 'concat':
1482
- xc = torch.cat([x] + c_concat, dim=1)
1483
- out = self.diffusion_model(xc, t)
1484
- elif self.conditioning_key == 'crossattn':
1485
- cc = torch.cat(c_crossattn, 1)
1486
- out = self.diffusion_model(x, t, context=cc)
1487
- elif self.conditioning_key == 'hybrid':
1488
- xc_0 = torch.cat([x_0] + c_concat, dim=1)
1489
- xc_1 = torch.cat([x_0, x_1] + c_concat, dim=1)
1490
- cc = torch.cat(c_crossattn, 1)
1491
- out_1, out_2 = self.diffusion_model(xc_0, xc_1, t, context=cc)
1492
- elif self.conditioning_key == 'adm':
1493
- cc = c_crossattn[0]
1494
- out = self.diffusion_model(x, t, y=cc)
1495
- else:
1496
- raise NotImplementedError()
1497
-
1498
- return out_1, out_2
1499
-
1500
-
1501
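To make the 'hybrid' branch of DiffusionWrapper concrete: the first UNet stream sees x_0 concatenated channel-wise with the image-conditioning latents, while the second stream additionally sees x_1. A toy shape check (all sizes below are illustrative assumptions, not taken from the configs):

import torch

b, c, h, w = 2, 4, 32, 32
x_0 = torch.randn(b, c, h, w)         # noisy latent, stream 0
x_1 = torch.randn(b, c, h, w)         # noisy latent, stream 1
c_concat = [torch.randn(b, 4, h, w)]  # image-conditioning latents

xc_0 = torch.cat([x_0] + c_concat, dim=1)       # (b, 8, h, w)  -> first UNet input
xc_1 = torch.cat([x_0, x_1] + c_concat, dim=1)  # (b, 12, h, w) -> second UNet input
print(xc_0.shape, xc_1.shape)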
- class Layout2ImgDiffusion(LatentDiffusion):
1502
- # TODO: move all layout-specific hacks to this class
1503
- def __init__(self, cond_stage_key, *args, **kwargs):
1504
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1505
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1506
-
1507
- def log_images(self, batch, N=8, *args, **kwargs):
1508
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1509
-
1510
- key = 'train' if self.training else 'validation'
1511
- dset = self.trainer.datamodule.datasets[key]
1512
- mapper = dset.conditional_builders[self.cond_stage_key]
1513
-
1514
- bbox_imgs = []
1515
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1516
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
1517
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1518
- bbox_imgs.append(bboximg)
1519
-
1520
- cond_img = torch.stack(bbox_imgs, dim=0)
1521
- logs['bbox_image'] = cond_img
1522
- return logs
stable_diffusion/ldm/modules/attention.py CHANGED
@@ -1,6 +1,3 @@
1
- # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
- # See more details in LICENSE.
3
-
4
  from inspect import isfunction
5
  import math
6
  import torch
 
 
 
 
1
  from inspect import isfunction
2
  import math
3
  import torch
stable_diffusion/ldm/modules/diffusionmodules/openaimodel_pam.py DELETED
@@ -1,1040 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
- import math
4
- from typing import Iterable
5
-
6
- import numpy as np
7
- import torch as th
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from ldm.modules.diffusionmodules.util import (
12
- checkpoint,
13
- conv_nd,
14
- linear,
15
- avg_pool_nd,
16
- zero_module,
17
- normalization,
18
- timestep_embedding,
19
- )
20
- from ldm.modules.attention import SpatialTransformer
21
-
22
-
23
- # dummy replace
24
- def convert_module_to_f16(x):
25
- pass
26
-
27
- def convert_module_to_f32(x):
28
- pass
29
-
30
-
31
- ## go
32
- class AttentionPool2d(nn.Module):
33
- """
34
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
- """
36
-
37
- def __init__(
38
- self,
39
- spacial_dim: int,
40
- embed_dim: int,
41
- num_heads_channels: int,
42
- output_dim: int = None,
43
- ):
44
- super().__init__()
45
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
- self.num_heads = embed_dim // num_heads_channels
49
- self.attention = QKVAttention(self.num_heads)
50
-
51
- def forward(self, x):
52
- b, c, *_spatial = x.shape
53
- x = x.reshape(b, c, -1) # NC(HW)
54
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
- x = self.qkv_proj(x)
57
- x = self.attention(x)
58
- x = self.c_proj(x)
59
- return x[:, :, 0]
60
-
61
-
62
- class TimestepBlock(nn.Module):
63
- """
64
- Any module where forward() takes timestep embeddings as a second argument.
65
- """
66
-
67
- @abstractmethod
68
- def forward(self, x, emb):
69
- """
70
- Apply the module to `x` given `emb` timestep embeddings.
71
- """
72
-
73
-
74
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
- """
76
- A sequential module that passes timestep embeddings to the children that
77
- support it as an extra input.
78
- """
79
-
80
- def forward(self, x, emb, context=None):
81
- for layer in self:
82
- if isinstance(layer, TimestepBlock):
83
- x = layer(x, emb)
84
- elif isinstance(layer, SpatialTransformer):
85
- x = layer(x, context)
86
- else:
87
- x = layer(x)
88
- return x
89
-
90
-
91
- class Upsample(nn.Module):
92
- """
93
- An upsampling layer with an optional convolution.
94
- :param channels: channels in the inputs and outputs.
95
- :param use_conv: a bool determining if a convolution is applied.
96
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
- upsampling occurs in the inner-two dimensions.
98
- """
99
-
100
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
- super().__init__()
102
- self.channels = channels
103
- self.out_channels = out_channels or channels
104
- self.use_conv = use_conv
105
- self.dims = dims
106
- if use_conv:
107
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
-
109
- def forward(self, x):
110
- assert x.shape[1] == self.channels
111
- if self.dims == 3:
112
- x = F.interpolate(
113
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
- )
115
- else:
116
- x = F.interpolate(x, scale_factor=2, mode="nearest")
117
- if self.use_conv:
118
- x = self.conv(x)
119
- return x
120
-
121
- class TransposedUpsample(nn.Module):
122
- 'Learned 2x upsampling without padding'
123
- def __init__(self, channels, out_channels=None, ks=5):
124
- super().__init__()
125
- self.channels = channels
126
- self.out_channels = out_channels or channels
127
-
128
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
-
130
- def forward(self,x):
131
- return self.up(x)
132
-
133
-
134
- class Downsample(nn.Module):
135
- """
136
- A downsampling layer with an optional convolution.
137
- :param channels: channels in the inputs and outputs.
138
- :param use_conv: a bool determining if a convolution is applied.
139
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
- downsampling occurs in the inner-two dimensions.
141
- """
142
-
143
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
- super().__init__()
145
- self.channels = channels
146
- self.out_channels = out_channels or channels
147
- self.use_conv = use_conv
148
- self.dims = dims
149
- stride = 2 if dims != 3 else (1, 2, 2)
150
- if use_conv:
151
- self.op = conv_nd(
152
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
- )
154
- else:
155
- assert self.channels == self.out_channels
156
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
-
158
- def forward(self, x):
159
- assert x.shape[1] == self.channels
160
- return self.op(x)
161
-
162
-
163
- class ResBlock(TimestepBlock):
164
- """
165
- A residual block that can optionally change the number of channels.
166
- :param channels: the number of input channels.
167
- :param emb_channels: the number of timestep embedding channels.
168
- :param dropout: the rate of dropout.
169
- :param out_channels: if specified, the number of out channels.
170
- :param use_conv: if True and out_channels is specified, use a spatial
171
- convolution instead of a smaller 1x1 convolution to change the
172
- channels in the skip connection.
173
- :param dims: determines if the signal is 1D, 2D, or 3D.
174
- :param use_checkpoint: if True, use gradient checkpointing on this module.
175
- :param up: if True, use this block for upsampling.
176
- :param down: if True, use this block for downsampling.
177
- """
178
-
179
- def __init__(
180
- self,
181
- channels,
182
- emb_channels,
183
- dropout,
184
- out_channels=None,
185
- use_conv=False,
186
- use_scale_shift_norm=False,
187
- dims=2,
188
- use_checkpoint=False,
189
- up=False,
190
- down=False,
191
- ):
192
- super().__init__()
193
- self.channels = channels
194
- self.emb_channels = emb_channels
195
- self.dropout = dropout
196
- self.out_channels = out_channels or channels
197
- self.use_conv = use_conv
198
- self.use_checkpoint = use_checkpoint
199
- self.use_scale_shift_norm = use_scale_shift_norm
200
-
201
- self.in_layers = nn.Sequential(
202
- normalization(channels),
203
- nn.SiLU(),
204
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
- )
206
-
207
- self.updown = up or down
208
-
209
- if up:
210
- self.h_upd = Upsample(channels, False, dims)
211
- self.x_upd = Upsample(channels, False, dims)
212
- elif down:
213
- self.h_upd = Downsample(channels, False, dims)
214
- self.x_upd = Downsample(channels, False, dims)
215
- else:
216
- self.h_upd = self.x_upd = nn.Identity()
217
-
218
- self.emb_layers = nn.Sequential(
219
- nn.SiLU(),
220
- linear(
221
- emb_channels,
222
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
- ),
224
- )
225
- self.out_layers = nn.Sequential(
226
- normalization(self.out_channels),
227
- nn.SiLU(),
228
- nn.Dropout(p=dropout),
229
- zero_module(
230
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
- ),
232
- )
233
-
234
- if self.out_channels == channels:
235
- self.skip_connection = nn.Identity()
236
- elif use_conv:
237
- self.skip_connection = conv_nd(
238
- dims, channels, self.out_channels, 3, padding=1
239
- )
240
- else:
241
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
-
243
- def forward(self, x, emb):
244
- """
245
- Apply the block to a Tensor, conditioned on a timestep embedding.
246
- :param x: an [N x C x ...] Tensor of features.
247
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
- :return: an [N x C x ...] Tensor of outputs.
249
- """
250
- return checkpoint(
251
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
- )
253
-
254
-
255
- def _forward(self, x, emb):
256
- if self.updown:
257
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
- h = in_rest(x)
259
- h = self.h_upd(h)
260
- x = self.x_upd(x)
261
- h = in_conv(h)
262
- else:
263
- h = self.in_layers(x)
264
- emb_out = self.emb_layers(emb).type(h.dtype)
265
- while len(emb_out.shape) < len(h.shape):
266
- emb_out = emb_out[..., None]
267
- if self.use_scale_shift_norm:
268
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
- scale, shift = th.chunk(emb_out, 2, dim=1)
270
- h = out_norm(h) * (1 + scale) + shift
271
- h = out_rest(h)
272
- else:
273
- h = h + emb_out
274
- h = self.out_layers(h)
275
- return self.skip_connection(x) + h
276
-
277
-
278
- class AttentionBlock(nn.Module):
279
- """
280
- An attention block that allows spatial positions to attend to each other.
281
- Originally ported from here, but adapted to the N-d case.
282
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
- """
284
-
285
- def __init__(
286
- self,
287
- channels,
288
- num_heads=1,
289
- num_head_channels=-1,
290
- use_checkpoint=False,
291
- use_new_attention_order=False,
292
- ):
293
- super().__init__()
294
- self.channels = channels
295
- if num_head_channels == -1:
296
- self.num_heads = num_heads
297
- else:
298
- assert (
299
- channels % num_head_channels == 0
300
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
- self.num_heads = channels // num_head_channels
302
- self.use_checkpoint = use_checkpoint
303
- self.norm = normalization(channels)
304
- self.qkv = conv_nd(1, channels, channels * 3, 1)
305
- if use_new_attention_order:
306
- # split qkv before split heads
307
- self.attention = QKVAttention(self.num_heads)
308
- else:
309
- # split heads before split qkv
310
- self.attention = QKVAttentionLegacy(self.num_heads)
311
-
312
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
-
314
- def forward(self, x):
315
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
- #return pt_checkpoint(self._forward, x) # pytorch
317
-
318
- def _forward(self, x):
319
- b, c, *spatial = x.shape
320
- x = x.reshape(b, c, -1)
321
- qkv = self.qkv(self.norm(x))
322
- h = self.attention(qkv)
323
- h = self.proj_out(h)
324
- return (x + h).reshape(b, c, *spatial)
325
-
326
-
327
- def count_flops_attn(model, _x, y):
328
- """
329
- A counter for the `thop` package to count the operations in an
330
- attention operation.
331
- Meant to be used like:
332
- macs, params = thop.profile(
333
- model,
334
- inputs=(inputs, timestamps),
335
- custom_ops={QKVAttention: QKVAttention.count_flops},
336
- )
337
- """
338
- b, c, *spatial = y[0].shape
339
- num_spatial = int(np.prod(spatial))
340
- # We perform two matmuls with the same number of ops.
341
- # The first computes the weight matrix, the second computes
342
- # the combination of the value vectors.
343
- matmul_ops = 2 * b * (num_spatial ** 2) * c
344
- model.total_ops += th.DoubleTensor([matmul_ops])
345
-
346
-
347
- class QKVAttentionLegacy(nn.Module):
348
- """
349
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
350
- """
351
-
352
- def __init__(self, n_heads):
353
- super().__init__()
354
- self.n_heads = n_heads
355
-
356
- def forward(self, qkv):
357
- """
358
- Apply QKV attention.
359
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
- :return: an [N x (H * C) x T] tensor after attention.
361
- """
362
- bs, width, length = qkv.shape
363
- assert width % (3 * self.n_heads) == 0
364
- ch = width // (3 * self.n_heads)
365
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
- scale = 1 / math.sqrt(math.sqrt(ch))
367
- weight = th.einsum(
368
- "bct,bcs->bts", q * scale, k * scale
369
- ) # More stable with f16 than dividing afterwards
370
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
- a = th.einsum("bts,bcs->bct", weight, v)
372
- return a.reshape(bs, -1, length)
373
-
374
- @staticmethod
375
- def count_flops(model, _x, y):
376
- return count_flops_attn(model, _x, y)
377
-
378
-
379
- class QKVAttention(nn.Module):
380
- """
381
- A module which performs QKV attention and splits in a different order.
382
- """
383
-
384
- def __init__(self, n_heads):
385
- super().__init__()
386
- self.n_heads = n_heads
387
-
388
- def forward(self, qkv):
389
- """
390
- Apply QKV attention.
391
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
- :return: an [N x (H * C) x T] tensor after attention.
393
- """
394
- bs, width, length = qkv.shape
395
- assert width % (3 * self.n_heads) == 0
396
- ch = width // (3 * self.n_heads)
397
- q, k, v = qkv.chunk(3, dim=1)
398
- scale = 1 / math.sqrt(math.sqrt(ch))
399
- weight = th.einsum(
400
- "bct,bcs->bts",
401
- (q * scale).view(bs * self.n_heads, ch, length),
402
- (k * scale).view(bs * self.n_heads, ch, length),
403
- ) # More stable with f16 than dividing afterwards
404
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
- return a.reshape(bs, -1, length)
407
-
408
- @staticmethod
409
- def count_flops(model, _x, y):
410
- return count_flops_attn(model, _x, y)
411
-
412
-
413
- class UNetModel(nn.Module):
414
- """
415
- The full UNet model with attention and timestep embedding.
416
- :param in_channels: channels in the input Tensor.
417
- :param model_channels: base channel count for the model.
418
- :param out_channels: channels in the output Tensor.
419
- :param num_res_blocks: number of residual blocks per downsample.
420
- :param attention_resolutions: a collection of downsample rates at which
421
- attention will take place. May be a set, list, or tuple.
422
- For example, if this contains 4, then at 4x downsampling, attention
423
- will be used.
424
- :param dropout: the dropout probability.
425
- :param channel_mult: channel multiplier for each level of the UNet.
426
- :param conv_resample: if True, use learned convolutions for upsampling and
427
- downsampling.
428
- :param dims: determines if the signal is 1D, 2D, or 3D.
429
- :param num_classes: if specified (as an int), then this model will be
430
- class-conditional with `num_classes` classes.
431
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
- :param num_heads: the number of attention heads in each attention layer.
433
- :param num_head_channels: if specified, ignore num_heads and instead use
434
- a fixed channel width per attention head.
435
- :param num_heads_upsample: works with num_heads to set a different number
436
- of heads for upsampling. Deprecated.
437
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
- :param resblock_updown: use residual blocks for up/downsampling.
439
- :param use_new_attention_order: use a different attention pattern for potentially
440
- increased efficiency.
441
- """
442
-
443
- def __init__(
444
- self,
445
- image_size,
446
- in_channels,
447
- model_channels,
448
- out_channels,
449
- num_res_blocks,
450
- attention_resolutions,
451
- in_mask_channels=0,
452
- dropout=0,
453
- channel_mult=(1, 2, 4, 8),
454
- conv_resample=True,
455
- dims=2,
456
- num_classes=None,
457
- use_checkpoint=False,
458
- use_fp16=False,
459
- num_heads=-1,
460
- num_head_channels=-1,
461
- num_heads_upsample=-1,
462
- use_scale_shift_norm=False,
463
- resblock_updown=False,
464
- use_new_attention_order=False,
465
- use_spatial_transformer=False, # custom transformer support
466
- transformer_depth=1, # custom transformer support
467
- context_dim=None, # custom transformer support
468
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
469
- legacy=True,
470
- independent_blocks_num=1, # custom support for independent blocks
471
- ):
472
- super().__init__()
473
- if use_spatial_transformer:
474
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
475
-
476
- if context_dim is not None:
477
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
478
- from omegaconf.listconfig import ListConfig
479
- if type(context_dim) == ListConfig:
480
- context_dim = list(context_dim)
481
-
482
- if num_heads_upsample == -1:
483
- num_heads_upsample = num_heads
484
-
485
- if num_heads == -1:
486
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
487
-
488
- if num_head_channels == -1:
489
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
490
-
491
- self.image_size = image_size
492
- self.in_channels = in_channels
493
- self.in_mask_channels = in_mask_channels
494
- self.model_channels = model_channels
495
- self.out_channels = out_channels
496
- self.num_res_blocks = num_res_blocks
497
- self.attention_resolutions = attention_resolutions
498
- self.dropout = dropout
499
- self.channel_mult = channel_mult
500
- self.conv_resample = conv_resample
501
- self.num_classes = num_classes
502
- self.use_checkpoint = use_checkpoint
503
- self.dtype = th.float16 if use_fp16 else th.float32
504
- self.num_heads = num_heads
505
- self.num_head_channels = num_head_channels
506
- self.num_heads_upsample = num_heads_upsample
507
- self.predict_codebook_ids = n_embed is not None
508
- self.independent_blocks_num = independent_blocks_num
509
- assert self.independent_blocks_num > 0 and self.independent_blocks_num <= len(channel_mult), 'Number of independent blocks should be between 1 and the number of blocks'
510
-
511
- time_embed_dim = model_channels * 4
512
- self.time_embed = nn.Sequential(
513
- linear(model_channels, time_embed_dim),
514
- nn.SiLU(),
515
- linear(time_embed_dim, time_embed_dim),
516
- )
517
-
518
- if self.num_classes is not None:
519
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
520
-
521
- self.input_blocks = nn.ModuleList(
522
- [
523
- TimestepEmbedSequential(
524
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
525
- )
526
- ]
527
- )
528
- self.input_blocks_branch_1 = nn.ModuleList(
529
- [
530
- TimestepEmbedSequential(
531
- conv_nd(dims,in_mask_channels if in_mask_channels != 0 else in_channels, model_channels, 3, padding=1)
532
- )
533
- ]
534
- )
535
- self.input_blocks_branch_1_available = [True]
536
- self._feature_size = model_channels
537
- input_block_chans = [model_channels]
538
- ch = model_channels
539
- ds = 1
540
- for level, mult in enumerate(channel_mult):
541
- for _ in range(num_res_blocks):
542
- layers = [
543
- ResBlock(
544
- ch,
545
- time_embed_dim,
546
- dropout,
547
- out_channels=mult * model_channels,
548
- dims=dims,
549
- use_checkpoint=use_checkpoint,
550
- use_scale_shift_norm=use_scale_shift_norm,
551
- )
552
- ]
553
- ch = mult * model_channels
554
- if ds in attention_resolutions:
555
- if num_head_channels == -1:
556
- dim_head = ch // num_heads
557
- else:
558
- num_heads = ch // num_head_channels
559
- dim_head = num_head_channels
560
- if legacy:
561
- #num_heads = 1
562
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
563
- layers.append(
564
- AttentionBlock(
565
- ch,
566
- use_checkpoint=use_checkpoint,
567
- num_heads=num_heads,
568
- num_head_channels=dim_head,
569
- use_new_attention_order=use_new_attention_order,
570
- ) if not use_spatial_transformer else SpatialTransformer(
571
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
572
- )
573
- )
574
- self.input_blocks.append(TimestepEmbedSequential(*layers))
575
- if level < self.independent_blocks_num:
576
- self.input_blocks_branch_1.append(TimestepEmbedSequential(*layers))
577
- self.input_blocks_branch_1_available.append(True)
578
- else:
579
- self.input_blocks_branch_1.append(nn.Sequential(nn.Identity()))
580
- self.input_blocks_branch_1_available.append(False)
581
- self._feature_size += ch
582
- input_block_chans.append(ch)
583
- if level != len(channel_mult) - 1:
584
- out_ch = ch
585
- self.input_blocks.append(
586
- TimestepEmbedSequential(
587
- ResBlock(
588
- ch,
589
- time_embed_dim,
590
- dropout,
591
- out_channels=out_ch,
592
- dims=dims,
593
- use_checkpoint=use_checkpoint,
594
- use_scale_shift_norm=use_scale_shift_norm,
595
- down=True,
596
- )
597
- if resblock_updown
598
- else Downsample(
599
- ch, conv_resample, dims=dims, out_channels=out_ch
600
- )
601
- )
602
- )
603
- if level < self.independent_blocks_num - 1:
604
- self.input_blocks_branch_1.append(
605
- TimestepEmbedSequential(
606
- ResBlock(
607
- ch,
608
- time_embed_dim,
609
- dropout,
610
- out_channels=out_ch,
611
- dims=dims,
612
- use_checkpoint=use_checkpoint,
613
- use_scale_shift_norm=use_scale_shift_norm,
614
- down=True,
615
- )
616
- if resblock_updown
617
- else Downsample(
618
- ch, conv_resample, dims=dims, out_channels=out_ch
619
- )
620
- )
621
- )
622
- self.input_blocks_branch_1_available.append(True)
623
- else:
624
- self.input_blocks_branch_1.append(nn.Sequential(nn.Identity()))
625
- self.input_blocks_branch_1_available.append(False)
626
- ch = out_ch
627
- input_block_chans.append(ch)
628
- ds *= 2
629
- self._feature_size += ch
630
- if num_head_channels == -1:
631
- dim_head = ch // num_heads
632
- else:
633
- num_heads = ch // num_head_channels
634
- dim_head = num_head_channels
635
- if legacy:
636
- #num_heads = 1
637
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
638
- self.middle_block = TimestepEmbedSequential(
639
- ResBlock(
640
- ch,
641
- time_embed_dim,
642
- dropout,
643
- dims=dims,
644
- use_checkpoint=use_checkpoint,
645
- use_scale_shift_norm=use_scale_shift_norm,
646
- ),
647
- AttentionBlock(
648
- ch,
649
- use_checkpoint=use_checkpoint,
650
- num_heads=num_heads,
651
- num_head_channels=dim_head,
652
- use_new_attention_order=use_new_attention_order,
653
- ) if not use_spatial_transformer else SpatialTransformer(
654
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
655
- ),
656
- ResBlock(
657
- ch,
658
- time_embed_dim,
659
- dropout,
660
- dims=dims,
661
- use_checkpoint=use_checkpoint,
662
- use_scale_shift_norm=use_scale_shift_norm,
663
- ),
664
- )
665
- self._feature_size += ch
666
-
667
- self.output_blocks = nn.ModuleList([])
668
- self.output_blocks_branch_1 = nn.ModuleList([])
669
- self.output_blocks_branch_1_available = []
670
-
671
- for level, mult in list(enumerate(channel_mult))[::-1]:
672
- for i in range(num_res_blocks + 1):
673
- ich = input_block_chans.pop()
674
- layers = [
675
- ResBlock(
676
- ch + ich,
677
- time_embed_dim,
678
- dropout,
679
- out_channels=model_channels * mult,
680
- dims=dims,
681
- use_checkpoint=use_checkpoint,
682
- use_scale_shift_norm=use_scale_shift_norm,
683
- )
684
- ]
685
- ch = model_channels * mult
686
- if ds in attention_resolutions:
687
- if num_head_channels == -1:
688
- dim_head = ch // num_heads
689
- else:
690
- num_heads = ch // num_head_channels
691
- dim_head = num_head_channels
692
- if legacy:
693
- #num_heads = 1
694
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
695
- layers.append(
696
- AttentionBlock(
697
- ch,
698
- use_checkpoint=use_checkpoint,
699
- num_heads=num_heads_upsample,
700
- num_head_channels=dim_head,
701
- use_new_attention_order=use_new_attention_order,
702
- ) if not use_spatial_transformer else SpatialTransformer(
703
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
704
- )
705
- )
706
- if level and i == num_res_blocks:
707
- out_ch = ch
708
- layers.append(
709
- ResBlock(
710
- ch,
711
- time_embed_dim,
712
- dropout,
713
- out_channels=out_ch,
714
- dims=dims,
715
- use_checkpoint=use_checkpoint,
716
- use_scale_shift_norm=use_scale_shift_norm,
717
- up=True,
718
- )
719
- if resblock_updown
720
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
721
- )
722
- ds //= 2
723
- self.output_blocks.append(TimestepEmbedSequential(*layers))
724
- if level < self.independent_blocks_num:
725
- self.output_blocks_branch_1.append(TimestepEmbedSequential(*layers))
726
- self.output_blocks_branch_1_available.append(True)
727
- else:
728
- self.output_blocks_branch_1.append(nn.Sequential(nn.Identity()))
729
- self.output_blocks_branch_1_available.append(False)
730
-
731
- self._feature_size += ch
732
-
733
- self.out = nn.Sequential(
734
- normalization(ch),
735
- nn.SiLU(),
736
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
737
- )
738
- self.out_branch_1 = nn.Sequential(
739
- normalization(ch),
740
- nn.SiLU(),
741
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
742
- )
743
- if self.predict_codebook_ids:
744
- self.id_predictor = nn.Sequential(
745
- normalization(ch),
746
- conv_nd(dims, model_channels, n_embed, 1),
747
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
748
- )
749
-
750
-
751
- def convert_to_fp16(self):
752
- """
753
- Convert the torso of the model to float16.
754
- """
755
- self.input_blocks.apply(convert_module_to_f16)
756
- self.middle_block.apply(convert_module_to_f16)
757
- self.output_blocks.apply(convert_module_to_f16)
758
-
759
- def convert_to_fp32(self):
760
- """
761
- Convert the torso of the model to float32.
762
- """
763
- self.input_blocks.apply(convert_module_to_f32)
764
- self.middle_block.apply(convert_module_to_f32)
765
- self.output_blocks.apply(convert_module_to_f32)
766
-
767
- def forward(self, x_0, x_1, timesteps=None, context=None, y=None,**kwargs):
768
- """
769
- Apply the model to an input batch.
770
- :param x_0: an [N x C x ...] Tensor of inputs.
771
- :param x_1: an [N x C x ...] Tensor of inputs.
772
- :param timesteps: a 1-D batch of timesteps.
773
- :param context: conditioning plugged in via crossattn
774
- :param y: an [N] Tensor of labels, if class-conditional.
775
- :return: an [N x C x ...] Tensor of outputs.
776
- """
777
- assert (y is not None) == (
778
- self.num_classes is not None
779
- ), "must specify y if and only if the model is class-conditional"
780
- hs_0 = []
781
- hs_1 = []
782
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
783
- emb = self.time_embed(t_emb)
784
-
785
- if self.num_classes is not None:
786
- assert y.shape == (x_0.shape[0],)
787
- emb = emb + self.label_emb(y)
788
-
789
- h_0 = x_0.type(self.dtype)
790
- h_1 = x_1.type(self.dtype)
791
- for index, module in enumerate(self.input_blocks):
792
- h_0 = module(h_0, emb, context)
793
-
794
- if self.input_blocks_branch_1_available[index]:
795
- module_branch_1 = self.input_blocks_branch_1[index]
796
- h_1 = module_branch_1(h_1, emb, context)
797
- else:
798
- h_1 = module(h_1, emb, context)
799
- hs_0.append(h_0)
800
- hs_1.append(h_1)
801
-
802
- h_0 = self.middle_block(h_0, emb, context)
803
- h_1 = self.middle_block(h_1, emb, context)
804
-
805
- for index, module in enumerate(self.output_blocks):
806
- h_0 = th.cat([h_0, hs_0.pop()], dim=1)
807
- h_0 = module(h_0, emb, context)
808
-
809
- h_1 = th.cat([h_1, hs_1.pop()], dim=1)
810
- if self.output_blocks_branch_1_available[index]:
811
- module_branch_1 = self.output_blocks_branch_1[index]
812
- h_1 = module_branch_1(h_1, emb, context)
813
- else:
814
- h_1 = module(h_1, emb, context)
815
-
816
- h_0 = h_0.type(x_0.dtype)
817
- h_1 = h_1.type(x_1.dtype)
818
- if self.predict_codebook_ids:
819
- return self.id_predictor(h_0), self.id_predictor(h_1)
820
- else:
821
- return self.out(h_0), self.out_branch_1(h_1)
822
-
823
-
824
- class EncoderUNetModel(nn.Module):
825
- """
826
- The half UNet model with attention and timestep embedding.
827
- For usage, see UNet.
828
- """
829
-
830
- def __init__(
831
- self,
832
- image_size,
833
- in_channels,
834
- model_channels,
835
- out_channels,
836
- num_res_blocks,
837
- attention_resolutions,
838
- dropout=0,
839
- channel_mult=(1, 2, 4, 8),
840
- conv_resample=True,
841
- dims=2,
842
- use_checkpoint=False,
843
- use_fp16=False,
844
- num_heads=1,
845
- num_head_channels=-1,
846
- num_heads_upsample=-1,
847
- use_scale_shift_norm=False,
848
- resblock_updown=False,
849
- use_new_attention_order=False,
850
- pool="adaptive",
851
- *args,
852
- **kwargs
853
- ):
854
- super().__init__()
855
-
856
- if num_heads_upsample == -1:
857
- num_heads_upsample = num_heads
858
-
859
- self.in_channels = in_channels
860
- self.model_channels = model_channels
861
- self.out_channels = out_channels
862
- self.num_res_blocks = num_res_blocks
863
- self.attention_resolutions = attention_resolutions
864
- self.dropout = dropout
865
- self.channel_mult = channel_mult
866
- self.conv_resample = conv_resample
867
- self.use_checkpoint = use_checkpoint
868
- self.dtype = th.float16 if use_fp16 else th.float32
869
- self.num_heads = num_heads
870
- self.num_head_channels = num_head_channels
871
- self.num_heads_upsample = num_heads_upsample
872
-
873
- time_embed_dim = model_channels * 4
874
- self.time_embed = nn.Sequential(
875
- linear(model_channels, time_embed_dim),
876
- nn.SiLU(),
877
- linear(time_embed_dim, time_embed_dim),
878
- )
879
-
880
- self.input_blocks = nn.ModuleList(
881
- [
882
- TimestepEmbedSequential(
883
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
884
- )
885
- ]
886
- )
887
- self._feature_size = model_channels
888
- input_block_chans = [model_channels]
889
- ch = model_channels
890
- ds = 1
891
- for level, mult in enumerate(channel_mult):
892
- for _ in range(num_res_blocks):
893
- layers = [
894
- ResBlock(
895
- ch,
896
- time_embed_dim,
897
- dropout,
898
- out_channels=mult * model_channels,
899
- dims=dims,
900
- use_checkpoint=use_checkpoint,
901
- use_scale_shift_norm=use_scale_shift_norm,
902
- )
903
- ]
904
- ch = mult * model_channels
905
- if ds in attention_resolutions:
906
- layers.append(
907
- AttentionBlock(
908
- ch,
909
- use_checkpoint=use_checkpoint,
910
- num_heads=num_heads,
911
- num_head_channels=num_head_channels,
912
- use_new_attention_order=use_new_attention_order,
913
- )
914
- )
915
- self.input_blocks.append(TimestepEmbedSequential(*layers))
916
- self._feature_size += ch
917
- input_block_chans.append(ch)
918
- if level != len(channel_mult) - 1:
919
- out_ch = ch
920
- self.input_blocks.append(
921
- TimestepEmbedSequential(
922
- ResBlock(
923
- ch,
924
- time_embed_dim,
925
- dropout,
926
- out_channels=out_ch,
927
- dims=dims,
928
- use_checkpoint=use_checkpoint,
929
- use_scale_shift_norm=use_scale_shift_norm,
930
- down=True,
931
- )
932
- if resblock_updown
933
- else Downsample(
934
- ch, conv_resample, dims=dims, out_channels=out_ch
935
- )
936
- )
937
- )
938
- ch = out_ch
939
- input_block_chans.append(ch)
940
- ds *= 2
941
- self._feature_size += ch
942
-
943
- self.middle_block = TimestepEmbedSequential(
944
- ResBlock(
945
- ch,
946
- time_embed_dim,
947
- dropout,
948
- dims=dims,
949
- use_checkpoint=use_checkpoint,
950
- use_scale_shift_norm=use_scale_shift_norm,
951
- ),
952
- AttentionBlock(
953
- ch,
954
- use_checkpoint=use_checkpoint,
955
- num_heads=num_heads,
956
- num_head_channels=num_head_channels,
957
- use_new_attention_order=use_new_attention_order,
958
- ),
959
- ResBlock(
960
- ch,
961
- time_embed_dim,
962
- dropout,
963
- dims=dims,
964
- use_checkpoint=use_checkpoint,
965
- use_scale_shift_norm=use_scale_shift_norm,
966
- ),
967
- )
968
- self._feature_size += ch
969
- self.pool = pool
970
- if pool == "adaptive":
971
- self.out = nn.Sequential(
972
- normalization(ch),
973
- nn.SiLU(),
974
- nn.AdaptiveAvgPool2d((1, 1)),
975
- zero_module(conv_nd(dims, ch, out_channels, 1)),
976
- nn.Flatten(),
977
- )
978
- elif pool == "attention":
979
- assert num_head_channels != -1
980
- self.out = nn.Sequential(
981
- normalization(ch),
982
- nn.SiLU(),
983
- AttentionPool2d(
984
- (image_size // ds), ch, num_head_channels, out_channels
985
- ),
986
- )
987
- elif pool == "spatial":
988
- self.out = nn.Sequential(
989
- nn.Linear(self._feature_size, 2048),
990
- nn.ReLU(),
991
- nn.Linear(2048, self.out_channels),
992
- )
993
- elif pool == "spatial_v2":
994
- self.out = nn.Sequential(
995
- nn.Linear(self._feature_size, 2048),
996
- normalization(2048),
997
- nn.SiLU(),
998
- nn.Linear(2048, self.out_channels),
999
- )
1000
- else:
1001
- raise NotImplementedError(f"Unexpected {pool} pooling")
1002
-
1003
- def convert_to_fp16(self):
1004
- """
1005
- Convert the torso of the model to float16.
1006
- """
1007
- self.input_blocks.apply(convert_module_to_f16)
1008
- self.middle_block.apply(convert_module_to_f16)
1009
-
1010
- def convert_to_fp32(self):
1011
- """
1012
- Convert the torso of the model to float32.
1013
- """
1014
- self.input_blocks.apply(convert_module_to_f32)
1015
- self.middle_block.apply(convert_module_to_f32)
1016
-
1017
- def forward(self, x, timesteps):
1018
- """
1019
- Apply the model to an input batch.
1020
- :param x: an [N x C x ...] Tensor of inputs.
1021
- :param timesteps: a 1-D batch of timesteps.
1022
- :return: an [N x K] Tensor of outputs.
1023
- """
1024
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1025
-
1026
- results = []
1027
- h = x.type(self.dtype)
1028
- for module in self.input_blocks:
1029
- h = module(h, emb)
1030
- if self.pool.startswith("spatial"):
1031
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1032
- h = self.middle_block(h, emb)
1033
- if self.pool.startswith("spatial"):
1034
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1035
- h = th.cat(results, axis=-1)
1036
- return self.out(h)
1037
- else:
1038
- h = h.type(x.dtype)
1039
- return self.out(h)
1040
-
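
Note on the block removed above: this UNet variant routes a second input `x_1` through duplicate `*_branch_1` input/output blocks for the first `independent_blocks_num` resolution levels, reuses the shared blocks (and the shared `middle_block`) everywhere else, and emits a separate prediction via `out_branch_1`. The sketch below only illustrates that dispatch pattern with simplified, hypothetical modules (`TwoBranchToy`, plain 3x3 convolutions); it is not the deleted code itself.

```python
# Minimal sketch of the two-branch dispatch pattern, assuming toy 3x3 conv
# "blocks" in place of the real ResBlock/Attention stacks. Names here
# (TwoBranchToy, shared, branch_1) are illustrative, not from the repo.
import torch
import torch.nn as nn


class TwoBranchToy(nn.Module):
    def __init__(self, channels=8, levels=3, independent_blocks_num=1):
        super().__init__()
        # One shared block per level, always used by branch 0.
        self.shared = nn.ModuleList(
            [nn.Conv2d(channels, channels, 3, padding=1) for _ in range(levels)]
        )
        # Branch 1 owns independent weights only for the first
        # `independent_blocks_num` levels; later levels fall back to sharing.
        self.branch_1 = nn.ModuleList(
            [
                nn.Conv2d(channels, channels, 3, padding=1)
                if level < independent_blocks_num
                else nn.Identity()
                for level in range(levels)
            ]
        )
        self.branch_1_available = [
            level < independent_blocks_num for level in range(levels)
        ]

    def forward(self, x_0, x_1):
        h_0, h_1 = x_0, x_1
        for level, shared_block in enumerate(self.shared):
            h_0 = shared_block(h_0)
            if self.branch_1_available[level]:
                h_1 = self.branch_1[level](h_1)  # independent weights
            else:
                h_1 = shared_block(h_1)          # weights shared with branch 0
        return h_0, h_1


if __name__ == "__main__":
    model = TwoBranchToy()
    a, b = model(torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16))
    print(a.shape, b.shape)  # both torch.Size([1, 8, 16, 16])
```

In the deleted file the same idea is applied per `channel_mult` level, with the `input_blocks_branch_1_available` / `output_blocks_branch_1_available` flags recording which levels carry independent weights.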
stable_diffusion/ldm/modules/diffusionmodules/openaimodel_pam_separate_mask.py DELETED
@@ -1,1091 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
- import math
4
- from typing import Iterable
5
-
6
- import numpy as np
7
- import torch as th
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from ldm.modules.diffusionmodules.util import (
12
- checkpoint,
13
- conv_nd,
14
- linear,
15
- avg_pool_nd,
16
- zero_module,
17
- normalization,
18
- timestep_embedding,
19
- )
20
- from ldm.modules.attention import SpatialTransformer
21
-
22
-
23
- # dummy replace
24
- def convert_module_to_f16(x):
25
- pass
26
-
27
- def convert_module_to_f32(x):
28
- pass
29
-
30
-
31
- ## go
32
- class AttentionPool2d(nn.Module):
33
- """
34
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
- """
36
-
37
- def __init__(
38
- self,
39
- spacial_dim: int,
40
- embed_dim: int,
41
- num_heads_channels: int,
42
- output_dim: int = None,
43
- ):
44
- super().__init__()
45
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
- self.num_heads = embed_dim // num_heads_channels
49
- self.attention = QKVAttention(self.num_heads)
50
-
51
- def forward(self, x):
52
- b, c, *_spatial = x.shape
53
- x = x.reshape(b, c, -1) # NC(HW)
54
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
- x = self.qkv_proj(x)
57
- x = self.attention(x)
58
- x = self.c_proj(x)
59
- return x[:, :, 0]
60
-
61
-
62
- class TimestepBlock(nn.Module):
63
- """
64
- Any module where forward() takes timestep embeddings as a second argument.
65
- """
66
-
67
- @abstractmethod
68
- def forward(self, x, emb):
69
- """
70
- Apply the module to `x` given `emb` timestep embeddings.
71
- """
72
-
73
-
74
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
- """
76
- A sequential module that passes timestep embeddings to the children that
77
- support it as an extra input.
78
- """
79
-
80
- def forward(self, x, emb, context=None):
81
- for layer in self:
82
- if isinstance(layer, TimestepBlock):
83
- x = layer(x, emb)
84
- elif isinstance(layer, SpatialTransformer):
85
- x = layer(x, context)
86
- else:
87
- x = layer(x)
88
- return x
89
-
90
-
91
- class Upsample(nn.Module):
92
- """
93
- An upsampling layer with an optional convolution.
94
- :param channels: channels in the inputs and outputs.
95
- :param use_conv: a bool determining if a convolution is applied.
96
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
- upsampling occurs in the inner-two dimensions.
98
- """
99
-
100
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
- super().__init__()
102
- self.channels = channels
103
- self.out_channels = out_channels or channels
104
- self.use_conv = use_conv
105
- self.dims = dims
106
- if use_conv:
107
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
-
109
- def forward(self, x):
110
- assert x.shape[1] == self.channels
111
- if self.dims == 3:
112
- x = F.interpolate(
113
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
- )
115
- else:
116
- x = F.interpolate(x, scale_factor=2, mode="nearest")
117
- if self.use_conv:
118
- x = self.conv(x)
119
- return x
120
-
121
- class TransposedUpsample(nn.Module):
122
- 'Learned 2x upsampling without padding'
123
- def __init__(self, channels, out_channels=None, ks=5):
124
- super().__init__()
125
- self.channels = channels
126
- self.out_channels = out_channels or channels
127
-
128
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
-
130
- def forward(self,x):
131
- return self.up(x)
132
-
133
-
134
- class Downsample(nn.Module):
135
- """
136
- A downsampling layer with an optional convolution.
137
- :param channels: channels in the inputs and outputs.
138
- :param use_conv: a bool determining if a convolution is applied.
139
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
- downsampling occurs in the inner-two dimensions.
141
- """
142
-
143
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
- super().__init__()
145
- self.channels = channels
146
- self.out_channels = out_channels or channels
147
- self.use_conv = use_conv
148
- self.dims = dims
149
- stride = 2 if dims != 3 else (1, 2, 2)
150
- if use_conv:
151
- self.op = conv_nd(
152
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
- )
154
- else:
155
- assert self.channels == self.out_channels
156
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
-
158
- def forward(self, x):
159
- assert x.shape[1] == self.channels
160
- return self.op(x)
161
-
162
-
163
- class ResBlockWithoutEmb(nn.Module):
164
- """
165
- A residual block that can optionally change the number of channels.
166
- :param channels: the number of input channels.
167
- :param dropout: the rate of dropout.
168
- :param out_channels: if specified, the number of out channels.
169
- :param use_conv: if True and out_channels is specified, use a spatial
170
- convolution instead of a smaller 1x1 convolution to change the
171
- channels in the skip connection.
172
- :param dims: determines if the signal is 1D, 2D, or 3D.
173
- :param use_checkpoint: if True, use gradient checkpointing on this module.
174
- :param up: if True, use this block for upsampling.
175
- :param down: if True, use this block for downsampling.
176
- """
177
-
178
- def __init__(
179
- self,
180
- channels,
181
- dropout,
182
- out_channels=None,
183
- use_conv=False,
184
- use_scale_shift_norm=False,
185
- dims=2,
186
- use_checkpoint=False,
187
- ):
188
- super().__init__()
189
- self.channels = channels
190
- self.dropout = dropout
191
- self.out_channels = out_channels or channels
192
- self.use_conv = use_conv
193
- self.use_checkpoint = use_checkpoint
194
- self.use_scale_shift_norm = use_scale_shift_norm
195
-
196
- self.in_layers = nn.Sequential(
197
- normalization(channels),
198
- nn.SiLU(),
199
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
200
- )
201
-
202
- self.out_layers = nn.Sequential(
203
- normalization(self.out_channels),
204
- nn.SiLU(),
205
- nn.Dropout(p=dropout),
206
- zero_module(
207
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
208
- ),
209
- )
210
-
211
- if self.out_channels == channels:
212
- self.skip_connection = nn.Identity()
213
- elif use_conv:
214
- self.skip_connection = conv_nd(
215
- dims, channels, self.out_channels, 3, padding=1
216
- )
217
- else:
218
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
219
-
220
- def forward(self, x):
221
- """
222
- Apply the block to a Tensor, conditioned on a timestep embedding.
223
- :param x: an [N x C x ...] Tensor of features.
224
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
225
- :return: an [N x C x ...] Tensor of outputs.
226
- """
227
- return checkpoint(
228
- self._forward, (x,), self.parameters(), self.use_checkpoint
229
- )
230
-
231
-
232
- def _forward(self, x):
233
- h = self.in_layers(x)
234
- h = self.out_layers(h)
235
- return self.skip_connection(x) + h
236
-
237
-
238
-
239
- class ResBlock(TimestepBlock):
240
- """
241
- A residual block that can optionally change the number of channels.
242
- :param channels: the number of input channels.
243
- :param emb_channels: the number of timestep embedding channels.
244
- :param dropout: the rate of dropout.
245
- :param out_channels: if specified, the number of out channels.
246
- :param use_conv: if True and out_channels is specified, use a spatial
247
- convolution instead of a smaller 1x1 convolution to change the
248
- channels in the skip connection.
249
- :param dims: determines if the signal is 1D, 2D, or 3D.
250
- :param use_checkpoint: if True, use gradient checkpointing on this module.
251
- :param up: if True, use this block for upsampling.
252
- :param down: if True, use this block for downsampling.
253
- """
254
-
255
- def __init__(
256
- self,
257
- channels,
258
- emb_channels,
259
- dropout,
260
- out_channels=None,
261
- use_conv=False,
262
- use_scale_shift_norm=False,
263
- dims=2,
264
- use_checkpoint=False,
265
- up=False,
266
- down=False,
267
- ):
268
- super().__init__()
269
- self.channels = channels
270
- self.emb_channels = emb_channels
271
- self.dropout = dropout
272
- self.out_channels = out_channels or channels
273
- self.use_conv = use_conv
274
- self.use_checkpoint = use_checkpoint
275
- self.use_scale_shift_norm = use_scale_shift_norm
276
-
277
- self.in_layers = nn.Sequential(
278
- normalization(channels),
279
- nn.SiLU(),
280
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
281
- )
282
-
283
- self.updown = up or down
284
-
285
- if up:
286
- self.h_upd = Upsample(channels, False, dims)
287
- self.x_upd = Upsample(channels, False, dims)
288
- elif down:
289
- self.h_upd = Downsample(channels, False, dims)
290
- self.x_upd = Downsample(channels, False, dims)
291
- else:
292
- self.h_upd = self.x_upd = nn.Identity()
293
-
294
- self.emb_layers = nn.Sequential(
295
- nn.SiLU(),
296
- linear(
297
- emb_channels,
298
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
299
- ),
300
- )
301
- self.out_layers = nn.Sequential(
302
- normalization(self.out_channels),
303
- nn.SiLU(),
304
- nn.Dropout(p=dropout),
305
- zero_module(
306
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
307
- ),
308
- )
309
-
310
- if self.out_channels == channels:
311
- self.skip_connection = nn.Identity()
312
- elif use_conv:
313
- self.skip_connection = conv_nd(
314
- dims, channels, self.out_channels, 3, padding=1
315
- )
316
- else:
317
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
318
-
319
- def forward(self, x, emb):
320
- """
321
- Apply the block to a Tensor, conditioned on a timestep embedding.
322
- :param x: an [N x C x ...] Tensor of features.
323
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
324
- :return: an [N x C x ...] Tensor of outputs.
325
- """
326
- return checkpoint(
327
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
328
- )
329
-
330
-
331
- def _forward(self, x, emb):
332
- if self.updown:
333
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
334
- h = in_rest(x)
335
- h = self.h_upd(h)
336
- x = self.x_upd(x)
337
- h = in_conv(h)
338
- else:
339
- h = self.in_layers(x)
340
- emb_out = self.emb_layers(emb).type(h.dtype)
341
- while len(emb_out.shape) < len(h.shape):
342
- emb_out = emb_out[..., None]
343
- if self.use_scale_shift_norm:
344
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
345
- scale, shift = th.chunk(emb_out, 2, dim=1)
346
- h = out_norm(h) * (1 + scale) + shift
347
- h = out_rest(h)
348
- else:
349
- h = h + emb_out
350
- h = self.out_layers(h)
351
- return self.skip_connection(x) + h
352
-
353
-
354
- class AttentionBlock(nn.Module):
355
- """
356
- An attention block that allows spatial positions to attend to each other.
357
- Originally ported from here, but adapted to the N-d case.
358
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
359
- """
360
-
361
- def __init__(
362
- self,
363
- channels,
364
- num_heads=1,
365
- num_head_channels=-1,
366
- use_checkpoint=False,
367
- use_new_attention_order=False,
368
- ):
369
- super().__init__()
370
- self.channels = channels
371
- if num_head_channels == -1:
372
- self.num_heads = num_heads
373
- else:
374
- assert (
375
- channels % num_head_channels == 0
376
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
377
- self.num_heads = channels // num_head_channels
378
- self.use_checkpoint = use_checkpoint
379
- self.norm = normalization(channels)
380
- self.qkv = conv_nd(1, channels, channels * 3, 1)
381
- if use_new_attention_order:
382
- # split qkv before split heads
383
- self.attention = QKVAttention(self.num_heads)
384
- else:
385
- # split heads before split qkv
386
- self.attention = QKVAttentionLegacy(self.num_heads)
387
-
388
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
389
-
390
- def forward(self, x):
391
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
392
- #return pt_checkpoint(self._forward, x) # pytorch
393
-
394
- def _forward(self, x):
395
- b, c, *spatial = x.shape
396
- x = x.reshape(b, c, -1)
397
- qkv = self.qkv(self.norm(x))
398
- h = self.attention(qkv)
399
- h = self.proj_out(h)
400
- return (x + h).reshape(b, c, *spatial)
401
-
402
-
403
- def count_flops_attn(model, _x, y):
404
- """
405
- A counter for the `thop` package to count the operations in an
406
- attention operation.
407
- Meant to be used like:
408
- macs, params = thop.profile(
409
- model,
410
- inputs=(inputs, timestamps),
411
- custom_ops={QKVAttention: QKVAttention.count_flops},
412
- )
413
- """
414
- b, c, *spatial = y[0].shape
415
- num_spatial = int(np.prod(spatial))
416
- # We perform two matmuls with the same number of ops.
417
- # The first computes the weight matrix, the second computes
418
- # the combination of the value vectors.
419
- matmul_ops = 2 * b * (num_spatial ** 2) * c
420
- model.total_ops += th.DoubleTensor([matmul_ops])
421
-
422
-
423
- class QKVAttentionLegacy(nn.Module):
424
- """
425
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
426
- """
427
-
428
- def __init__(self, n_heads):
429
- super().__init__()
430
- self.n_heads = n_heads
431
-
432
- def forward(self, qkv):
433
- """
434
- Apply QKV attention.
435
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
436
- :return: an [N x (H * C) x T] tensor after attention.
437
- """
438
- bs, width, length = qkv.shape
439
- assert width % (3 * self.n_heads) == 0
440
- ch = width // (3 * self.n_heads)
441
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
442
- scale = 1 / math.sqrt(math.sqrt(ch))
443
- weight = th.einsum(
444
- "bct,bcs->bts", q * scale, k * scale
445
- ) # More stable with f16 than dividing afterwards
446
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
447
- a = th.einsum("bts,bcs->bct", weight, v)
448
- return a.reshape(bs, -1, length)
449
-
450
- @staticmethod
451
- def count_flops(model, _x, y):
452
- return count_flops_attn(model, _x, y)
453
-
454
-
455
- class QKVAttention(nn.Module):
456
- """
457
- A module which performs QKV attention and splits in a different order.
458
- """
459
-
460
- def __init__(self, n_heads):
461
- super().__init__()
462
- self.n_heads = n_heads
463
-
464
- def forward(self, qkv):
465
- """
466
- Apply QKV attention.
467
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
468
- :return: an [N x (H * C) x T] tensor after attention.
469
- """
470
- bs, width, length = qkv.shape
471
- assert width % (3 * self.n_heads) == 0
472
- ch = width // (3 * self.n_heads)
473
- q, k, v = qkv.chunk(3, dim=1)
474
- scale = 1 / math.sqrt(math.sqrt(ch))
475
- weight = th.einsum(
476
- "bct,bcs->bts",
477
- (q * scale).view(bs * self.n_heads, ch, length),
478
- (k * scale).view(bs * self.n_heads, ch, length),
479
- ) # More stable with f16 than dividing afterwards
480
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
481
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
482
- return a.reshape(bs, -1, length)
483
-
484
- @staticmethod
485
- def count_flops(model, _x, y):
486
- return count_flops_attn(model, _x, y)
487
-
488
-
489
- class UNetModel(nn.Module):
490
- """
491
- The full UNet model with attention and timestep embedding.
492
- :param in_channels: channels in the input Tensor.
493
- :param model_channels: base channel count for the model.
494
- :param out_channels: channels in the output Tensor.
495
- :param num_res_blocks: number of residual blocks per downsample.
496
- :param attention_resolutions: a collection of downsample rates at which
497
- attention will take place. May be a set, list, or tuple.
498
- For example, if this contains 4, then at 4x downsampling, attention
499
- will be used.
500
- :param dropout: the dropout probability.
501
- :param channel_mult: channel multiplier for each level of the UNet.
502
- :param conv_resample: if True, use learned convolutions for upsampling and
503
- downsampling.
504
- :param dims: determines if the signal is 1D, 2D, or 3D.
505
- :param num_classes: if specified (as an int), then this model will be
506
- class-conditional with `num_classes` classes.
507
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
508
- :param num_heads: the number of attention heads in each attention layer.
509
- :param num_heads_channels: if specified, ignore num_heads and instead use
510
- a fixed channel width per attention head.
511
- :param num_heads_upsample: works with num_heads to set a different number
512
- of heads for upsampling. Deprecated.
513
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
514
- :param resblock_updown: use residual blocks for up/downsampling.
515
- :param use_new_attention_order: use a different attention pattern for potentially
516
- increased efficiency.
517
- """
518
-
519
- def __init__(
520
- self,
521
- image_size,
522
- in_channels,
523
- model_channels,
524
- out_channels,
525
- num_res_blocks,
526
- attention_resolutions,
527
- dropout=0,
528
- channel_mult=(1, 2, 4, 8),
529
- conv_resample=True,
530
- dims=2,
531
- num_classes=None,
532
- use_checkpoint=False,
533
- use_fp16=False,
534
- num_heads=-1,
535
- num_head_channels=-1,
536
- num_heads_upsample=-1,
537
- use_scale_shift_norm=False,
538
- resblock_updown=False,
539
- use_new_attention_order=False,
540
- use_spatial_transformer=False, # custom transformer support
541
- transformer_depth=1, # custom transformer support
542
- context_dim=None, # custom transformer support
543
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
544
- legacy=True,
545
- ):
546
- super().__init__()
547
- if use_spatial_transformer:
548
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
549
-
550
- if context_dim is not None:
551
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
552
- from omegaconf.listconfig import ListConfig
553
- if type(context_dim) == ListConfig:
554
- context_dim = list(context_dim)
555
-
556
- if num_heads_upsample == -1:
557
- num_heads_upsample = num_heads
558
-
559
- if num_heads == -1:
560
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
561
-
562
- if num_head_channels == -1:
563
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
564
-
565
- self.image_size = image_size
566
- self.in_channels = in_channels
567
- self.model_channels = model_channels
568
- self.out_channels = out_channels
569
- self.num_res_blocks = num_res_blocks
570
- self.attention_resolutions = attention_resolutions
571
- self.dropout = dropout
572
- self.channel_mult = channel_mult
573
- self.conv_resample = conv_resample
574
- self.num_classes = num_classes
575
- self.use_checkpoint = use_checkpoint
576
- self.dtype = th.float16 if use_fp16 else th.float32
577
- self.num_heads = num_heads
578
- self.num_head_channels = num_head_channels
579
- self.num_heads_upsample = num_heads_upsample
580
- self.predict_codebook_ids = n_embed is not None
581
-
582
- time_embed_dim = model_channels * 4
583
- self.time_embed = nn.Sequential(
584
- linear(model_channels, time_embed_dim),
585
- nn.SiLU(),
586
- linear(time_embed_dim, time_embed_dim),
587
- )
588
-
589
- if self.num_classes is not None:
590
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
591
-
592
- self.input_blocks = nn.ModuleList(
593
- [
594
- TimestepEmbedSequential(
595
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
596
- )
597
- ]
598
- )
599
-
600
- self.mask_blocks = TimestepEmbedSequential(
601
- conv_nd(dims, in_channels, model_channels, 3, padding=1),
602
- ResBlockWithoutEmb(
603
- model_channels,
604
- dropout,
605
- dims=dims,
606
- use_checkpoint=use_checkpoint,
607
- use_scale_shift_norm=use_scale_shift_norm,
608
- ),
609
- AttentionBlock(
610
- model_channels,
611
- use_checkpoint=use_checkpoint,
612
- num_heads=num_heads,
613
- num_head_channels=model_channels//num_heads,
614
- use_new_attention_order=use_new_attention_order,
615
- ),
616
- ResBlockWithoutEmb(
617
- model_channels,
618
- dropout,
619
- dims=dims,
620
- use_checkpoint=use_checkpoint,
621
- use_scale_shift_norm=use_scale_shift_norm,
622
- ),
623
- nn.Sequential(
624
- normalization(model_channels),
625
- nn.SiLU(),
626
- zero_module(conv_nd(dims, model_channels, 1, 3, padding=1)),
627
- )
628
- )
629
-
630
- self._feature_size = model_channels
631
- input_block_chans = [model_channels]
632
- ch = model_channels
633
- ds = 1
634
- for level, mult in enumerate(channel_mult):
635
- for _ in range(num_res_blocks):
636
- layers = [
637
- ResBlock(
638
- ch,
639
- time_embed_dim,
640
- dropout,
641
- out_channels=mult * model_channels,
642
- dims=dims,
643
- use_checkpoint=use_checkpoint,
644
- use_scale_shift_norm=use_scale_shift_norm,
645
- )
646
- ]
647
- ch = mult * model_channels
648
- if ds in attention_resolutions:
649
- if num_head_channels == -1:
650
- dim_head = ch // num_heads
651
- else:
652
- num_heads = ch // num_head_channels
653
- dim_head = num_head_channels
654
- if legacy:
655
- #num_heads = 1
656
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
657
- layers.append(
658
- AttentionBlock(
659
- ch,
660
- use_checkpoint=use_checkpoint,
661
- num_heads=num_heads,
662
- num_head_channels=dim_head,
663
- use_new_attention_order=use_new_attention_order,
664
- ) if not use_spatial_transformer else SpatialTransformer(
665
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
666
- )
667
- )
668
- self.input_blocks.append(TimestepEmbedSequential(*layers))
669
- self._feature_size += ch
670
- input_block_chans.append(ch)
671
- if level != len(channel_mult) - 1:
672
- out_ch = ch
673
- self.input_blocks.append(
674
- TimestepEmbedSequential(
675
- ResBlock(
676
- ch,
677
- time_embed_dim,
678
- dropout,
679
- out_channels=out_ch,
680
- dims=dims,
681
- use_checkpoint=use_checkpoint,
682
- use_scale_shift_norm=use_scale_shift_norm,
683
- down=True,
684
- )
685
- if resblock_updown
686
- else Downsample(
687
- ch, conv_resample, dims=dims, out_channels=out_ch
688
- )
689
- )
690
- )
691
- ch = out_ch
692
- input_block_chans.append(ch)
693
- ds *= 2
694
- self._feature_size += ch
695
-
696
- if num_head_channels == -1:
697
- dim_head = ch // num_heads
698
- else:
699
- num_heads = ch // num_head_channels
700
- dim_head = num_head_channels
701
- if legacy:
702
- #num_heads = 1
703
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
704
- self.middle_block = TimestepEmbedSequential(
705
- ResBlock(
706
- ch,
707
- time_embed_dim,
708
- dropout,
709
- dims=dims,
710
- use_checkpoint=use_checkpoint,
711
- use_scale_shift_norm=use_scale_shift_norm,
712
- ),
713
- AttentionBlock(
714
- ch,
715
- use_checkpoint=use_checkpoint,
716
- num_heads=num_heads,
717
- num_head_channels=dim_head,
718
- use_new_attention_order=use_new_attention_order,
719
- ) if not use_spatial_transformer else SpatialTransformer(
720
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
721
- ),
722
- ResBlock(
723
- ch,
724
- time_embed_dim,
725
- dropout,
726
- dims=dims,
727
- use_checkpoint=use_checkpoint,
728
- use_scale_shift_norm=use_scale_shift_norm,
729
- ),
730
- )
731
- self._feature_size += ch
732
-
733
- self.output_blocks = nn.ModuleList([])
734
- for level, mult in list(enumerate(channel_mult))[::-1]:
735
- for i in range(num_res_blocks + 1):
736
- ich = input_block_chans.pop()
737
- layers = [
738
- ResBlock(
739
- ch + ich,
740
- time_embed_dim,
741
- dropout,
742
- out_channels=model_channels * mult,
743
- dims=dims,
744
- use_checkpoint=use_checkpoint,
745
- use_scale_shift_norm=use_scale_shift_norm,
746
- )
747
- ]
748
- ch = model_channels * mult
749
- if ds in attention_resolutions:
750
- if num_head_channels == -1:
751
- dim_head = ch // num_heads
752
- else:
753
- num_heads = ch // num_head_channels
754
- dim_head = num_head_channels
755
- if legacy:
756
- #num_heads = 1
757
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
758
- layers.append(
759
- AttentionBlock(
760
- ch,
761
- use_checkpoint=use_checkpoint,
762
- num_heads=num_heads_upsample,
763
- num_head_channels=dim_head,
764
- use_new_attention_order=use_new_attention_order,
765
- ) if not use_spatial_transformer else SpatialTransformer(
766
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
767
- )
768
- )
769
- if level and i == num_res_blocks:
770
- out_ch = ch
771
- layers.append(
772
- ResBlock(
773
- ch,
774
- time_embed_dim,
775
- dropout,
776
- out_channels=out_ch,
777
- dims=dims,
778
- use_checkpoint=use_checkpoint,
779
- use_scale_shift_norm=use_scale_shift_norm,
780
- up=True,
781
- )
782
- if resblock_updown
783
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
784
- )
785
- ds //= 2
786
- self.output_blocks.append(TimestepEmbedSequential(*layers))
787
- self._feature_size += ch
788
-
789
- self.out_mask = None
790
-
791
- self.out = nn.Sequential(
792
- normalization(ch),
793
- nn.SiLU(),
794
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
795
- )
796
- if self.predict_codebook_ids:
797
- self.id_predictor = nn.Sequential(
798
- normalization(ch),
799
- conv_nd(dims, model_channels, n_embed, 1),
800
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
801
- )
802
-
803
- def convert_to_fp16(self):
804
- """
805
- Convert the torso of the model to float16.
806
- """
807
- self.input_blocks.apply(convert_module_to_f16)
808
- self.middle_block.apply(convert_module_to_f16)
809
- self.output_blocks.apply(convert_module_to_f16)
810
-
811
- def convert_to_fp32(self):
812
- """
813
- Convert the torso of the model to float32.
814
- """
815
- self.input_blocks.apply(convert_module_to_f32)
816
- self.middle_block.apply(convert_module_to_f32)
817
- self.output_blocks.apply(convert_module_to_f32)
818
-
819
- def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
820
- """
821
- Apply the model to an input batch.
822
- :param x: an [N x C x ...] Tensor of inputs.
823
- :param timesteps: a 1-D batch of timesteps.
824
- :param context: conditioning plugged in via crossattn
825
- :param y: an [N] Tensor of labels, if class-conditional.
826
- :return: an [N x C x ...] Tensor of outputs.
827
- """
828
- assert (y is not None) == (
829
- self.num_classes is not None
830
- ), "must specify y if and only if the model is class-conditional"
831
- hs = []
832
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
833
- emb = self.time_embed(t_emb)
834
-
835
- if self.num_classes is not None:
836
- assert y.shape == (x.shape[0],)
837
- emb = emb + self.label_emb(y)
838
-
839
- # mask blocks
840
- m = x.type(self.dtype)
841
- m = self.mask_blocks(m, emb, context)
842
-
843
- # unet
844
- h = x.type(self.dtype)
845
- for module in self.input_blocks:
846
- h = module(h, emb, context)
847
- hs.append(h)
848
- h = self.middle_block(h, emb, context)
849
- for module in self.output_blocks:
850
- h = th.cat([h, hs.pop()], dim=1)
851
- h = module(h, emb, context)
852
- h = h.type(x.dtype)
853
- if self.predict_codebook_ids:
854
- return self.id_predictor(h)
855
- else:
856
- return self.out(h), m
857
-
858
- def decode_mask(self, x, timesteps=None, context=None, y=None,**kwargs):
859
- """
860
- Apply the model to an input batch.
861
- :param x: an [N x C x ...] Tensor of inputs.
862
- :param timesteps: a 1-D batch of timesteps.
863
- :param context: conditioning plugged in via crossattn
864
- :param y: an [N] Tensor of labels, if class-conditional.
865
- :return: an [N x C x ...] Tensor of outputs.
866
- """
867
-
868
- # mask blocks
869
- m = x.type(self.dtype)
870
- m = self.mask_blocks(m, None, context)
871
-
872
- return m
873
-
874
-
875
- class EncoderUNetModel(nn.Module):
876
- """
877
- The half UNet model with attention and timestep embedding.
878
- For usage, see UNet.
879
- """
880
-
881
- def __init__(
882
- self,
883
- image_size,
884
- in_channels,
885
- model_channels,
886
- out_channels,
887
- num_res_blocks,
888
- attention_resolutions,
889
- dropout=0,
890
- channel_mult=(1, 2, 4, 8),
891
- conv_resample=True,
892
- dims=2,
893
- use_checkpoint=False,
894
- use_fp16=False,
895
- num_heads=1,
896
- num_head_channels=-1,
897
- num_heads_upsample=-1,
898
- use_scale_shift_norm=False,
899
- resblock_updown=False,
900
- use_new_attention_order=False,
901
- pool="adaptive",
902
- *args,
903
- **kwargs
904
- ):
905
- super().__init__()
906
-
907
- if num_heads_upsample == -1:
908
- num_heads_upsample = num_heads
909
-
910
- self.in_channels = in_channels
911
- self.model_channels = model_channels
912
- self.out_channels = out_channels
913
- self.num_res_blocks = num_res_blocks
914
- self.attention_resolutions = attention_resolutions
915
- self.dropout = dropout
916
- self.channel_mult = channel_mult
917
- self.conv_resample = conv_resample
918
- self.use_checkpoint = use_checkpoint
919
- self.dtype = th.float16 if use_fp16 else th.float32
920
- self.num_heads = num_heads
921
- self.num_head_channels = num_head_channels
922
- self.num_heads_upsample = num_heads_upsample
923
-
924
- time_embed_dim = model_channels * 4
925
- self.time_embed = nn.Sequential(
926
- linear(model_channels, time_embed_dim),
927
- nn.SiLU(),
928
- linear(time_embed_dim, time_embed_dim),
929
- )
930
-
931
- self.input_blocks = nn.ModuleList(
932
- [
933
- TimestepEmbedSequential(
934
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
935
- )
936
- ]
937
- )
938
- self._feature_size = model_channels
939
- input_block_chans = [model_channels]
940
- ch = model_channels
941
- ds = 1
942
- for level, mult in enumerate(channel_mult):
943
- for _ in range(num_res_blocks):
944
- layers = [
945
- ResBlock(
946
- ch,
947
- time_embed_dim,
948
- dropout,
949
- out_channels=mult * model_channels,
950
- dims=dims,
951
- use_checkpoint=use_checkpoint,
952
- use_scale_shift_norm=use_scale_shift_norm,
953
- )
954
- ]
955
- ch = mult * model_channels
956
- if ds in attention_resolutions:
957
- layers.append(
958
- AttentionBlock(
959
- ch,
960
- use_checkpoint=use_checkpoint,
961
- num_heads=num_heads,
962
- num_head_channels=num_head_channels,
963
- use_new_attention_order=use_new_attention_order,
964
- )
965
- )
966
- self.input_blocks.append(TimestepEmbedSequential(*layers))
967
- self._feature_size += ch
968
- input_block_chans.append(ch)
969
- if level != len(channel_mult) - 1:
970
- out_ch = ch
971
- self.input_blocks.append(
972
- TimestepEmbedSequential(
973
- ResBlock(
974
- ch,
975
- time_embed_dim,
976
- dropout,
977
- out_channels=out_ch,
978
- dims=dims,
979
- use_checkpoint=use_checkpoint,
980
- use_scale_shift_norm=use_scale_shift_norm,
981
- down=True,
982
- )
983
- if resblock_updown
984
- else Downsample(
985
- ch, conv_resample, dims=dims, out_channels=out_ch
986
- )
987
- )
988
- )
989
- ch = out_ch
990
- input_block_chans.append(ch)
991
- ds *= 2
992
- self._feature_size += ch
993
-
994
- self.middle_block = TimestepEmbedSequential(
995
- ResBlock(
996
- ch,
997
- time_embed_dim,
998
- dropout,
999
- dims=dims,
1000
- use_checkpoint=use_checkpoint,
1001
- use_scale_shift_norm=use_scale_shift_norm,
1002
- ),
1003
- AttentionBlock(
1004
- ch,
1005
- use_checkpoint=use_checkpoint,
1006
- num_heads=num_heads,
1007
- num_head_channels=num_head_channels,
1008
- use_new_attention_order=use_new_attention_order,
1009
- ),
1010
- ResBlock(
1011
- ch,
1012
- time_embed_dim,
1013
- dropout,
1014
- dims=dims,
1015
- use_checkpoint=use_checkpoint,
1016
- use_scale_shift_norm=use_scale_shift_norm,
1017
- ),
1018
- )
1019
- self._feature_size += ch
1020
- self.pool = pool
1021
- if pool == "adaptive":
1022
- self.out = nn.Sequential(
1023
- normalization(ch),
1024
- nn.SiLU(),
1025
- nn.AdaptiveAvgPool2d((1, 1)),
1026
- zero_module(conv_nd(dims, ch, out_channels, 1)),
1027
- nn.Flatten(),
1028
- )
1029
- elif pool == "attention":
1030
- assert num_head_channels != -1
1031
- self.out = nn.Sequential(
1032
- normalization(ch),
1033
- nn.SiLU(),
1034
- AttentionPool2d(
1035
- (image_size // ds), ch, num_head_channels, out_channels
1036
- ),
1037
- )
1038
- elif pool == "spatial":
1039
- self.out = nn.Sequential(
1040
- nn.Linear(self._feature_size, 2048),
1041
- nn.ReLU(),
1042
- nn.Linear(2048, self.out_channels),
1043
- )
1044
- elif pool == "spatial_v2":
1045
- self.out = nn.Sequential(
1046
- nn.Linear(self._feature_size, 2048),
1047
- normalization(2048),
1048
- nn.SiLU(),
1049
- nn.Linear(2048, self.out_channels),
1050
- )
1051
- else:
1052
- raise NotImplementedError(f"Unexpected {pool} pooling")
1053
-
1054
- def convert_to_fp16(self):
1055
- """
1056
- Convert the torso of the model to float16.
1057
- """
1058
- self.input_blocks.apply(convert_module_to_f16)
1059
- self.middle_block.apply(convert_module_to_f16)
1060
-
1061
- def convert_to_fp32(self):
1062
- """
1063
- Convert the torso of the model to float32.
1064
- """
1065
- self.input_blocks.apply(convert_module_to_f32)
1066
- self.middle_block.apply(convert_module_to_f32)
1067
-
1068
- def forward(self, x, timesteps):
1069
- """
1070
- Apply the model to an input batch.
1071
- :param x: an [N x C x ...] Tensor of inputs.
1072
- :param timesteps: a 1-D batch of timesteps.
1073
- :return: an [N x K] Tensor of outputs.
1074
- """
1075
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1076
-
1077
- results = []
1078
- h = x.type(self.dtype)
1079
- for module in self.input_blocks:
1080
- h = module(h, emb)
1081
- if self.pool.startswith("spatial"):
1082
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1083
- h = self.middle_block(h, emb)
1084
- if self.pool.startswith("spatial"):
1085
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1086
- h = th.cat(results, axis=-1)
1087
- return self.out(h)
1088
- else:
1089
- h = h.type(x.dtype)
1090
- return self.out(h)
1091
-

stable_diffusion/ldm/modules/diffusionmodules/openaimodel_pam_test.py DELETED
@@ -1,1040 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
- import math
4
- from typing import Iterable
5
-
6
- import numpy as np
7
- import torch as th
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from ldm.modules.diffusionmodules.util import (
12
- checkpoint,
13
- conv_nd,
14
- linear,
15
- avg_pool_nd,
16
- zero_module,
17
- normalization,
18
- timestep_embedding,
19
- )
20
- from ldm.modules.attention import SpatialTransformer
21
-
22
-
23
- # dummy replace
24
- def convert_module_to_f16(x):
25
- pass
26
-
27
- def convert_module_to_f32(x):
28
- pass
29
-
30
-
31
- ## go
32
- class AttentionPool2d(nn.Module):
33
- """
34
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
- """
36
-
37
- def __init__(
38
- self,
39
- spacial_dim: int,
40
- embed_dim: int,
41
- num_heads_channels: int,
42
- output_dim: int = None,
43
- ):
44
- super().__init__()
45
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
- self.num_heads = embed_dim // num_heads_channels
49
- self.attention = QKVAttention(self.num_heads)
50
-
51
- def forward(self, x):
52
- b, c, *_spatial = x.shape
53
- x = x.reshape(b, c, -1) # NC(HW)
54
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
- x = self.qkv_proj(x)
57
- x = self.attention(x)
58
- x = self.c_proj(x)
59
- return x[:, :, 0]
60
-
61
-
62
- class TimestepBlock(nn.Module):
63
- """
64
- Any module where forward() takes timestep embeddings as a second argument.
65
- """
66
-
67
- @abstractmethod
68
- def forward(self, x, emb):
69
- """
70
- Apply the module to `x` given `emb` timestep embeddings.
71
- """
72
-
73
-
74
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
- """
76
- A sequential module that passes timestep embeddings to the children that
77
- support it as an extra input.
78
- """
79
-
80
- def forward(self, x, emb, context=None):
81
- for layer in self:
82
- if isinstance(layer, TimestepBlock):
83
- x = layer(x, emb)
84
- elif isinstance(layer, SpatialTransformer):
85
- x = layer(x, context)
86
- else:
87
- x = layer(x)
88
- return x
89
-
90
-
91
- class Upsample(nn.Module):
92
- """
93
- An upsampling layer with an optional convolution.
94
- :param channels: channels in the inputs and outputs.
95
- :param use_conv: a bool determining if a convolution is applied.
96
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
- upsampling occurs in the inner-two dimensions.
98
- """
99
-
100
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
- super().__init__()
102
- self.channels = channels
103
- self.out_channels = out_channels or channels
104
- self.use_conv = use_conv
105
- self.dims = dims
106
- if use_conv:
107
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
-
109
- def forward(self, x):
110
- assert x.shape[1] == self.channels
111
- if self.dims == 3:
112
- x = F.interpolate(
113
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
- )
115
- else:
116
- x = F.interpolate(x, scale_factor=2, mode="nearest")
117
- if self.use_conv:
118
- x = self.conv(x)
119
- return x
120
-
121
- class TransposedUpsample(nn.Module):
122
- 'Learned 2x upsampling without padding'
123
- def __init__(self, channels, out_channels=None, ks=5):
124
- super().__init__()
125
- self.channels = channels
126
- self.out_channels = out_channels or channels
127
-
128
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
-
130
- def forward(self,x):
131
- return self.up(x)
132
-
133
-
134
- class Downsample(nn.Module):
135
- """
136
- A downsampling layer with an optional convolution.
137
- :param channels: channels in the inputs and outputs.
138
- :param use_conv: a bool determining if a convolution is applied.
139
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
- downsampling occurs in the inner-two dimensions.
141
- """
142
-
143
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
- super().__init__()
145
- self.channels = channels
146
- self.out_channels = out_channels or channels
147
- self.use_conv = use_conv
148
- self.dims = dims
149
- stride = 2 if dims != 3 else (1, 2, 2)
150
- if use_conv:
151
- self.op = conv_nd(
152
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
- )
154
- else:
155
- assert self.channels == self.out_channels
156
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
-
158
- def forward(self, x):
159
- assert x.shape[1] == self.channels
160
- return self.op(x)
161
-
162
-
163
- class ResBlock(TimestepBlock):
164
- """
165
- A residual block that can optionally change the number of channels.
166
- :param channels: the number of input channels.
167
- :param emb_channels: the number of timestep embedding channels.
168
- :param dropout: the rate of dropout.
169
- :param out_channels: if specified, the number of out channels.
170
- :param use_conv: if True and out_channels is specified, use a spatial
171
- convolution instead of a smaller 1x1 convolution to change the
172
- channels in the skip connection.
173
- :param dims: determines if the signal is 1D, 2D, or 3D.
174
- :param use_checkpoint: if True, use gradient checkpointing on this module.
175
- :param up: if True, use this block for upsampling.
176
- :param down: if True, use this block for downsampling.
177
- """
178
-
179
- def __init__(
180
- self,
181
- channels,
182
- emb_channels,
183
- dropout,
184
- out_channels=None,
185
- use_conv=False,
186
- use_scale_shift_norm=False,
187
- dims=2,
188
- use_checkpoint=False,
189
- up=False,
190
- down=False,
191
- ):
192
- super().__init__()
193
- self.channels = channels
194
- self.emb_channels = emb_channels
195
- self.dropout = dropout
196
- self.out_channels = out_channels or channels
197
- self.use_conv = use_conv
198
- self.use_checkpoint = use_checkpoint
199
- self.use_scale_shift_norm = use_scale_shift_norm
200
-
201
- self.in_layers = nn.Sequential(
202
- normalization(channels),
203
- nn.SiLU(),
204
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
- )
206
-
207
- self.updown = up or down
208
-
209
- if up:
210
- self.h_upd = Upsample(channels, False, dims)
211
- self.x_upd = Upsample(channels, False, dims)
212
- elif down:
213
- self.h_upd = Downsample(channels, False, dims)
214
- self.x_upd = Downsample(channels, False, dims)
215
- else:
216
- self.h_upd = self.x_upd = nn.Identity()
217
-
218
- self.emb_layers = nn.Sequential(
219
- nn.SiLU(),
220
- linear(
221
- emb_channels,
222
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
- ),
224
- )
225
- self.out_layers = nn.Sequential(
226
- normalization(self.out_channels),
227
- nn.SiLU(),
228
- nn.Dropout(p=dropout),
229
- zero_module(
230
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
- ),
232
- )
233
-
234
- if self.out_channels == channels:
235
- self.skip_connection = nn.Identity()
236
- elif use_conv:
237
- self.skip_connection = conv_nd(
238
- dims, channels, self.out_channels, 3, padding=1
239
- )
240
- else:
241
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
-
243
- def forward(self, x, emb):
244
- """
245
- Apply the block to a Tensor, conditioned on a timestep embedding.
246
- :param x: an [N x C x ...] Tensor of features.
247
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
- :return: an [N x C x ...] Tensor of outputs.
249
- """
250
- return checkpoint(
251
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
- )
253
-
254
-
255
- def _forward(self, x, emb):
256
- if self.updown:
257
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
- h = in_rest(x)
259
- h = self.h_upd(h)
260
- x = self.x_upd(x)
261
- h = in_conv(h)
262
- else:
263
- h = self.in_layers(x)
264
- emb_out = self.emb_layers(emb).type(h.dtype)
265
- while len(emb_out.shape) < len(h.shape):
266
- emb_out = emb_out[..., None]
267
- if self.use_scale_shift_norm:
268
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
- scale, shift = th.chunk(emb_out, 2, dim=1)
270
- h = out_norm(h) * (1 + scale) + shift
271
- h = out_rest(h)
272
- else:
273
- h = h + emb_out
274
- h = self.out_layers(h)
275
- return self.skip_connection(x) + h
276
-
277
-
278
- class AttentionBlock(nn.Module):
279
- """
280
- An attention block that allows spatial positions to attend to each other.
281
- Originally ported from here, but adapted to the N-d case.
282
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
- """
284
-
285
- def __init__(
286
- self,
287
- channels,
288
- num_heads=1,
289
- num_head_channels=-1,
290
- use_checkpoint=False,
291
- use_new_attention_order=False,
292
- ):
293
- super().__init__()
294
- self.channels = channels
295
- if num_head_channels == -1:
296
- self.num_heads = num_heads
297
- else:
298
- assert (
299
- channels % num_head_channels == 0
300
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
- self.num_heads = channels // num_head_channels
302
- self.use_checkpoint = use_checkpoint
303
- self.norm = normalization(channels)
304
- self.qkv = conv_nd(1, channels, channels * 3, 1)
305
- if use_new_attention_order:
306
- # split qkv before split heads
307
- self.attention = QKVAttention(self.num_heads)
308
- else:
309
- # split heads before split qkv
310
- self.attention = QKVAttentionLegacy(self.num_heads)
311
-
312
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
-
314
- def forward(self, x):
315
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
- #return pt_checkpoint(self._forward, x) # pytorch
317
-
318
- def _forward(self, x):
319
- b, c, *spatial = x.shape
320
- x = x.reshape(b, c, -1)
321
- qkv = self.qkv(self.norm(x))
322
- h = self.attention(qkv)
323
- h = self.proj_out(h)
324
- return (x + h).reshape(b, c, *spatial)
325
-
326
-
327
- def count_flops_attn(model, _x, y):
328
- """
329
- A counter for the `thop` package to count the operations in an
330
- attention operation.
331
- Meant to be used like:
332
- macs, params = thop.profile(
333
- model,
334
- inputs=(inputs, timestamps),
335
- custom_ops={QKVAttention: QKVAttention.count_flops},
336
- )
337
- """
338
- b, c, *spatial = y[0].shape
339
- num_spatial = int(np.prod(spatial))
340
- # We perform two matmuls with the same number of ops.
341
- # The first computes the weight matrix, the second computes
342
- # the combination of the value vectors.
343
- matmul_ops = 2 * b * (num_spatial ** 2) * c
344
- model.total_ops += th.DoubleTensor([matmul_ops])
345
-
346
-
347
- class QKVAttentionLegacy(nn.Module):
348
- """
349
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
350
- """
351
-
352
- def __init__(self, n_heads):
353
- super().__init__()
354
- self.n_heads = n_heads
355
-
356
- def forward(self, qkv):
357
- """
358
- Apply QKV attention.
359
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
- :return: an [N x (H * C) x T] tensor after attention.
361
- """
362
- bs, width, length = qkv.shape
363
- assert width % (3 * self.n_heads) == 0
364
- ch = width // (3 * self.n_heads)
365
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
- scale = 1 / math.sqrt(math.sqrt(ch))
367
- weight = th.einsum(
368
- "bct,bcs->bts", q * scale, k * scale
369
- ) # More stable with f16 than dividing afterwards
370
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
- a = th.einsum("bts,bcs->bct", weight, v)
372
- return a.reshape(bs, -1, length)
373
-
374
- @staticmethod
375
- def count_flops(model, _x, y):
376
- return count_flops_attn(model, _x, y)
377
-
378
-
379
- class QKVAttention(nn.Module):
380
- """
381
- A module which performs QKV attention and splits in a different order.
382
- """
383
-
384
- def __init__(self, n_heads):
385
- super().__init__()
386
- self.n_heads = n_heads
387
-
388
- def forward(self, qkv):
389
- """
390
- Apply QKV attention.
391
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
- :return: an [N x (H * C) x T] tensor after attention.
393
- """
394
- bs, width, length = qkv.shape
395
- assert width % (3 * self.n_heads) == 0
396
- ch = width // (3 * self.n_heads)
397
- q, k, v = qkv.chunk(3, dim=1)
398
- scale = 1 / math.sqrt(math.sqrt(ch))
399
- weight = th.einsum(
400
- "bct,bcs->bts",
401
- (q * scale).view(bs * self.n_heads, ch, length),
402
- (k * scale).view(bs * self.n_heads, ch, length),
403
- ) # More stable with f16 than dividing afterwards
404
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
- return a.reshape(bs, -1, length)
407
-
408
- @staticmethod
409
- def count_flops(model, _x, y):
410
- return count_flops_attn(model, _x, y)
411
-
412
-
413
- class UNetModel(nn.Module):
414
- """
415
- The full UNet model with attention and timestep embedding.
416
- :param in_channels: channels in the input Tensor.
417
- :param model_channels: base channel count for the model.
418
- :param out_channels: channels in the output Tensor.
419
- :param num_res_blocks: number of residual blocks per downsample.
420
- :param attention_resolutions: a collection of downsample rates at which
421
- attention will take place. May be a set, list, or tuple.
422
- For example, if this contains 4, then at 4x downsampling, attention
423
- will be used.
424
- :param dropout: the dropout probability.
425
- :param channel_mult: channel multiplier for each level of the UNet.
426
- :param conv_resample: if True, use learned convolutions for upsampling and
427
- downsampling.
428
- :param dims: determines if the signal is 1D, 2D, or 3D.
429
- :param num_classes: if specified (as an int), then this model will be
430
- class-conditional with `num_classes` classes.
431
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
- :param num_heads: the number of attention heads in each attention layer.
433
- :param num_heads_channels: if specified, ignore num_heads and instead use
434
- a fixed channel width per attention head.
435
- :param num_heads_upsample: works with num_heads to set a different number
436
- of heads for upsampling. Deprecated.
437
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
- :param resblock_updown: use residual blocks for up/downsampling.
439
- :param use_new_attention_order: use a different attention pattern for potentially
440
- increased efficiency.
441
- """
442
-
443
- def __init__(
444
- self,
445
- image_size,
446
- in_channels,
447
- in_mask_channels,
448
- model_channels,
449
- out_channels,
450
- num_res_blocks,
451
- attention_resolutions,
452
- dropout=0,
453
- channel_mult=(1, 2, 4, 8),
454
- conv_resample=True,
455
- dims=2,
456
- num_classes=None,
457
- use_checkpoint=False,
458
- use_fp16=False,
459
- num_heads=-1,
460
- num_head_channels=-1,
461
- num_heads_upsample=-1,
462
- use_scale_shift_norm=False,
463
- resblock_updown=False,
464
- use_new_attention_order=False,
465
- use_spatial_transformer=False, # custom transformer support
466
- transformer_depth=1, # custom transformer support
467
- context_dim=None, # custom transformer support
468
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
469
- legacy=True,
470
- independent_blocks_num=1, # custom support for independent blocks
471
- ):
472
- super().__init__()
473
- if use_spatial_transformer:
474
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
475
-
476
- if context_dim is not None:
477
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
478
- from omegaconf.listconfig import ListConfig
479
- if type(context_dim) == ListConfig:
480
- context_dim = list(context_dim)
481
-
482
- if num_heads_upsample == -1:
483
- num_heads_upsample = num_heads
484
-
485
- if num_heads == -1:
486
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
487
-
488
- if num_head_channels == -1:
489
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
490
-
491
- self.image_size = image_size
492
- self.in_channels = in_channels
493
- self.in_mask_channels = in_mask_channels
494
- self.model_channels = model_channels
495
- self.out_channels = out_channels
496
- self.num_res_blocks = num_res_blocks
497
- self.attention_resolutions = attention_resolutions
498
- self.dropout = dropout
499
- self.channel_mult = channel_mult
500
- self.conv_resample = conv_resample
501
- self.num_classes = num_classes
502
- self.use_checkpoint = use_checkpoint
503
- self.dtype = th.float16 if use_fp16 else th.float32
504
- self.num_heads = num_heads
505
- self.num_head_channels = num_head_channels
506
- self.num_heads_upsample = num_heads_upsample
507
- self.predict_codebook_ids = n_embed is not None
508
- self.independent_blocks_num = independent_blocks_num
509
- assert self.independent_blocks_num > 0 and self.independent_blocks_num <= len(channel_mult), 'Number of independent blocks should be between 1 and the number of blocks'
510
-
511
- time_embed_dim = model_channels * 4
512
- self.time_embed = nn.Sequential(
513
- linear(model_channels, time_embed_dim),
514
- nn.SiLU(),
515
- linear(time_embed_dim, time_embed_dim),
516
- )
517
-
518
- if self.num_classes is not None:
519
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
520
-
521
- self.input_blocks = nn.ModuleList(
522
- [
523
- TimestepEmbedSequential(
524
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
525
- )
526
- ]
527
- )
528
- self.input_blocks_branch_1 = nn.ModuleList(
529
- [
530
- TimestepEmbedSequential(
531
- conv_nd(dims, in_mask_channels, model_channels, 3, padding=1)
532
- )
533
- ]
534
- )
535
- self.input_blocks_branch_1_available = [True]
536
- self._feature_size = model_channels
537
- input_block_chans = [model_channels]
538
- ch = model_channels
539
- ds = 1
540
- for level, mult in enumerate(channel_mult):
541
- for _ in range(num_res_blocks):
542
- layers = [
543
- ResBlock(
544
- ch,
545
- time_embed_dim,
546
- dropout,
547
- out_channels=mult * model_channels,
548
- dims=dims,
549
- use_checkpoint=use_checkpoint,
550
- use_scale_shift_norm=use_scale_shift_norm,
551
- )
552
- ]
553
- ch = mult * model_channels
554
- if ds in attention_resolutions:
555
- if num_head_channels == -1:
556
- dim_head = ch // num_heads
557
- else:
558
- num_heads = ch // num_head_channels
559
- dim_head = num_head_channels
560
- if legacy:
561
- #num_heads = 1
562
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
563
- layers.append(
564
- AttentionBlock(
565
- ch,
566
- use_checkpoint=use_checkpoint,
567
- num_heads=num_heads,
568
- num_head_channels=dim_head,
569
- use_new_attention_order=use_new_attention_order,
570
- ) if not use_spatial_transformer else SpatialTransformer(
571
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
572
- )
573
- )
574
- self.input_blocks.append(TimestepEmbedSequential(*layers))
575
- if level < self.independent_blocks_num:
576
- self.input_blocks_branch_1.append(TimestepEmbedSequential(*layers))
577
- self.input_blocks_branch_1_available.append(True)
578
- else:
579
- self.input_blocks_branch_1.append(nn.Sequential(nn.Identity()))
580
- self.input_blocks_branch_1_available.append(False)
581
- self._feature_size += ch
582
- input_block_chans.append(ch)
583
- if level != len(channel_mult) - 1:
584
- out_ch = ch
585
- self.input_blocks.append(
586
- TimestepEmbedSequential(
587
- ResBlock(
588
- ch,
589
- time_embed_dim,
590
- dropout,
591
- out_channels=out_ch,
592
- dims=dims,
593
- use_checkpoint=use_checkpoint,
594
- use_scale_shift_norm=use_scale_shift_norm,
595
- down=True,
596
- )
597
- if resblock_updown
598
- else Downsample(
599
- ch, conv_resample, dims=dims, out_channels=out_ch
600
- )
601
- )
602
- )
603
- if level < self.independent_blocks_num - 1:
604
- self.input_blocks_branch_1.append(
605
- TimestepEmbedSequential(
606
- ResBlock(
607
- ch,
608
- time_embed_dim,
609
- dropout,
610
- out_channels=out_ch,
611
- dims=dims,
612
- use_checkpoint=use_checkpoint,
613
- use_scale_shift_norm=use_scale_shift_norm,
614
- down=True,
615
- )
616
- if resblock_updown
617
- else Downsample(
618
- ch, conv_resample, dims=dims, out_channels=out_ch
619
- )
620
- )
621
- )
622
- self.input_blocks_branch_1_available.append(True)
623
- else:
624
- self.input_blocks_branch_1.append(nn.Sequential(nn.Identity()))
625
- self.input_blocks_branch_1_available.append(False)
626
- ch = out_ch
627
- input_block_chans.append(ch)
628
- ds *= 2
629
- self._feature_size += ch
630
- if num_head_channels == -1:
631
- dim_head = ch // num_heads
632
- else:
633
- num_heads = ch // num_head_channels
634
- dim_head = num_head_channels
635
- if legacy:
636
- #num_heads = 1
637
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
638
- self.middle_block = TimestepEmbedSequential(
639
- ResBlock(
640
- ch,
641
- time_embed_dim,
642
- dropout,
643
- dims=dims,
644
- use_checkpoint=use_checkpoint,
645
- use_scale_shift_norm=use_scale_shift_norm,
646
- ),
647
- AttentionBlock(
648
- ch,
649
- use_checkpoint=use_checkpoint,
650
- num_heads=num_heads,
651
- num_head_channels=dim_head,
652
- use_new_attention_order=use_new_attention_order,
653
- ) if not use_spatial_transformer else SpatialTransformer(
654
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
655
- ),
656
- ResBlock(
657
- ch,
658
- time_embed_dim,
659
- dropout,
660
- dims=dims,
661
- use_checkpoint=use_checkpoint,
662
- use_scale_shift_norm=use_scale_shift_norm,
663
- ),
664
- )
665
- self._feature_size += ch
666
-
667
- self.output_blocks = nn.ModuleList([])
668
- self.output_blocks_branch_1 = nn.ModuleList([])
669
- self.output_blocks_branch_1_available = []
670
-
671
- for level, mult in list(enumerate(channel_mult))[::-1]:
672
- for i in range(num_res_blocks + 1):
673
- ich = input_block_chans.pop()
674
- layers = [
675
- ResBlock(
676
- ch + ich,
677
- time_embed_dim,
678
- dropout,
679
- out_channels=model_channels * mult,
680
- dims=dims,
681
- use_checkpoint=use_checkpoint,
682
- use_scale_shift_norm=use_scale_shift_norm,
683
- )
684
- ]
685
- ch = model_channels * mult
686
- if ds in attention_resolutions:
687
- if num_head_channels == -1:
688
- dim_head = ch // num_heads
689
- else:
690
- num_heads = ch // num_head_channels
691
- dim_head = num_head_channels
692
- if legacy:
693
- #num_heads = 1
694
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
695
- layers.append(
696
- AttentionBlock(
697
- ch,
698
- use_checkpoint=use_checkpoint,
699
- num_heads=num_heads_upsample,
700
- num_head_channels=dim_head,
701
- use_new_attention_order=use_new_attention_order,
702
- ) if not use_spatial_transformer else SpatialTransformer(
703
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
704
- )
705
- )
706
- if level and i == num_res_blocks:
707
- out_ch = ch
708
- layers.append(
709
- ResBlock(
710
- ch,
711
- time_embed_dim,
712
- dropout,
713
- out_channels=out_ch,
714
- dims=dims,
715
- use_checkpoint=use_checkpoint,
716
- use_scale_shift_norm=use_scale_shift_norm,
717
- up=True,
718
- )
719
- if resblock_updown
720
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
721
- )
722
- ds //= 2
723
- self.output_blocks.append(TimestepEmbedSequential(*layers))
724
- if level < self.independent_blocks_num:
725
- self.output_blocks_branch_1.append(TimestepEmbedSequential(*layers))
726
- self.output_blocks_branch_1_available.append(True)
727
- else:
728
- self.output_blocks_branch_1.append(nn.Sequential(nn.Identity()))
729
- self.output_blocks_branch_1_available.append(False)
730
-
731
- self._feature_size += ch
732
-
733
- self.out = nn.Sequential(
734
- normalization(ch),
735
- nn.SiLU(),
736
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
737
- )
738
- self.out_branch_1 = nn.Sequential(
739
- normalization(ch),
740
- nn.SiLU(),
741
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
742
- )
743
- if self.predict_codebook_ids:
744
- self.id_predictor = nn.Sequential(
745
- normalization(ch),
746
- conv_nd(dims, model_channels, n_embed, 1),
747
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
748
- )
749
-
750
-
751
- def convert_to_fp16(self):
752
- """
753
- Convert the torso of the model to float16.
754
- """
755
- self.input_blocks.apply(convert_module_to_f16)
756
- self.middle_block.apply(convert_module_to_f16)
757
- self.output_blocks.apply(convert_module_to_f16)
758
-
759
- def convert_to_fp32(self):
760
- """
761
- Convert the torso of the model to float32.
762
- """
763
- self.input_blocks.apply(convert_module_to_f32)
764
- self.middle_block.apply(convert_module_to_f32)
765
- self.output_blocks.apply(convert_module_to_f32)
766
-
767
- def forward(self, x_0, x_1, timesteps=None, context=None, y=None,**kwargs):
768
- """
769
- Apply the model to an input batch.
770
- :param x_0: an [N x C x ...] Tensor of inputs.
771
- :param x_1: an [N x C x ...] Tensor of inputs.
772
- :param timesteps: a 1-D batch of timesteps.
773
- :param context: conditioning plugged in via crossattn
774
- :param y: an [N] Tensor of labels, if class-conditional.
775
- :return: an [N x C x ...] Tensor of outputs.
776
- """
777
- assert (y is not None) == (
778
- self.num_classes is not None
779
- ), "must specify y if and only if the model is class-conditional"
780
- hs_0 = []
781
- hs_1 = []
782
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
783
- emb = self.time_embed(t_emb)
784
-
785
- if self.num_classes is not None:
786
- assert y.shape == (x.shape[0],)
787
- emb = emb + self.label_emb(y)
788
-
789
- h_0 = x_0.type(self.dtype)
790
- h_1 = x_1.type(self.dtype)
791
- for index, module in enumerate(self.input_blocks):
792
- h_0 = module(h_0, emb, context)
793
-
794
- if self.input_blocks_branch_1_available[index]:
795
- module_branch_1 = self.input_blocks_branch_1[index]
796
- h_1 = module_branch_1(h_1, emb, context)
797
- else:
798
- h_1 = module(h_1, emb, context)
799
- hs_0.append(h_0)
800
- hs_1.append(h_1)
801
-
802
- h_0 = self.middle_block(h_0, emb, context)
803
- h_1 = self.middle_block(h_1, emb, context)
804
-
805
- for index, module in enumerate(self.output_blocks):
806
- h_0 = th.cat([h_0, hs_0.pop()], dim=1)
807
- h_0 = module(h_0, emb, context)
808
-
809
- h_1 = th.cat([h_1, hs_1.pop()], dim=1)
810
- if self.output_blocks_branch_1_available[index]:
811
- module_branch_1 = self.output_blocks_branch_1[index]
812
- h_1 = module_branch_1(h_1, emb, context)
813
- else:
814
- h_1 = module(h_1, emb, context)
815
-
816
- h_0 = h_0.type(x_0.dtype)
817
- h_1 = h_1.type(x_1.dtype)
818
- if self.predict_codebook_ids:
819
- return self.id_predictor(h_0), self.id_predictor(h_1)
820
- else:
821
- return self.out(h_0), self.out_branch_1(h_1)
822
-
823
-
824
- class EncoderUNetModel(nn.Module):
825
- """
826
- The half UNet model with attention and timestep embedding.
827
- For usage, see UNet.
828
- """
829
-
830
- def __init__(
831
- self,
832
- image_size,
833
- in_channels,
834
- model_channels,
835
- out_channels,
836
- num_res_blocks,
837
- attention_resolutions,
838
- dropout=0,
839
- channel_mult=(1, 2, 4, 8),
840
- conv_resample=True,
841
- dims=2,
842
- use_checkpoint=False,
843
- use_fp16=False,
844
- num_heads=1,
845
- num_head_channels=-1,
846
- num_heads_upsample=-1,
847
- use_scale_shift_norm=False,
848
- resblock_updown=False,
849
- use_new_attention_order=False,
850
- pool="adaptive",
851
- *args,
852
- **kwargs
853
- ):
854
- super().__init__()
855
-
856
- if num_heads_upsample == -1:
857
- num_heads_upsample = num_heads
858
-
859
- self.in_channels = in_channels
860
- self.model_channels = model_channels
861
- self.out_channels = out_channels
862
- self.num_res_blocks = num_res_blocks
863
- self.attention_resolutions = attention_resolutions
864
- self.dropout = dropout
865
- self.channel_mult = channel_mult
866
- self.conv_resample = conv_resample
867
- self.use_checkpoint = use_checkpoint
868
- self.dtype = th.float16 if use_fp16 else th.float32
869
- self.num_heads = num_heads
870
- self.num_head_channels = num_head_channels
871
- self.num_heads_upsample = num_heads_upsample
872
-
873
- time_embed_dim = model_channels * 4
874
- self.time_embed = nn.Sequential(
875
- linear(model_channels, time_embed_dim),
876
- nn.SiLU(),
877
- linear(time_embed_dim, time_embed_dim),
878
- )
879
-
880
- self.input_blocks = nn.ModuleList(
881
- [
882
- TimestepEmbedSequential(
883
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
884
- )
885
- ]
886
- )
887
- self._feature_size = model_channels
888
- input_block_chans = [model_channels]
889
- ch = model_channels
890
- ds = 1
891
- for level, mult in enumerate(channel_mult):
892
- for _ in range(num_res_blocks):
893
- layers = [
894
- ResBlock(
895
- ch,
896
- time_embed_dim,
897
- dropout,
898
- out_channels=mult * model_channels,
899
- dims=dims,
900
- use_checkpoint=use_checkpoint,
901
- use_scale_shift_norm=use_scale_shift_norm,
902
- )
903
- ]
904
- ch = mult * model_channels
905
- if ds in attention_resolutions:
906
- layers.append(
907
- AttentionBlock(
908
- ch,
909
- use_checkpoint=use_checkpoint,
910
- num_heads=num_heads,
911
- num_head_channels=num_head_channels,
912
- use_new_attention_order=use_new_attention_order,
913
- )
914
- )
915
- self.input_blocks.append(TimestepEmbedSequential(*layers))
916
- self._feature_size += ch
917
- input_block_chans.append(ch)
918
- if level != len(channel_mult) - 1:
919
- out_ch = ch
920
- self.input_blocks.append(
921
- TimestepEmbedSequential(
922
- ResBlock(
923
- ch,
924
- time_embed_dim,
925
- dropout,
926
- out_channels=out_ch,
927
- dims=dims,
928
- use_checkpoint=use_checkpoint,
929
- use_scale_shift_norm=use_scale_shift_norm,
930
- down=True,
931
- )
932
- if resblock_updown
933
- else Downsample(
934
- ch, conv_resample, dims=dims, out_channels=out_ch
935
- )
936
- )
937
- )
938
- ch = out_ch
939
- input_block_chans.append(ch)
940
- ds *= 2
941
- self._feature_size += ch
942
-
943
- self.middle_block = TimestepEmbedSequential(
944
- ResBlock(
945
- ch,
946
- time_embed_dim,
947
- dropout,
948
- dims=dims,
949
- use_checkpoint=use_checkpoint,
950
- use_scale_shift_norm=use_scale_shift_norm,
951
- ),
952
- AttentionBlock(
953
- ch,
954
- use_checkpoint=use_checkpoint,
955
- num_heads=num_heads,
956
- num_head_channels=num_head_channels,
957
- use_new_attention_order=use_new_attention_order,
958
- ),
959
- ResBlock(
960
- ch,
961
- time_embed_dim,
962
- dropout,
963
- dims=dims,
964
- use_checkpoint=use_checkpoint,
965
- use_scale_shift_norm=use_scale_shift_norm,
966
- ),
967
- )
968
- self._feature_size += ch
969
- self.pool = pool
970
- if pool == "adaptive":
971
- self.out = nn.Sequential(
972
- normalization(ch),
973
- nn.SiLU(),
974
- nn.AdaptiveAvgPool2d((1, 1)),
975
- zero_module(conv_nd(dims, ch, out_channels, 1)),
976
- nn.Flatten(),
977
- )
978
- elif pool == "attention":
979
- assert num_head_channels != -1
980
- self.out = nn.Sequential(
981
- normalization(ch),
982
- nn.SiLU(),
983
- AttentionPool2d(
984
- (image_size // ds), ch, num_head_channels, out_channels
985
- ),
986
- )
987
- elif pool == "spatial":
988
- self.out = nn.Sequential(
989
- nn.Linear(self._feature_size, 2048),
990
- nn.ReLU(),
991
- nn.Linear(2048, self.out_channels),
992
- )
993
- elif pool == "spatial_v2":
994
- self.out = nn.Sequential(
995
- nn.Linear(self._feature_size, 2048),
996
- normalization(2048),
997
- nn.SiLU(),
998
- nn.Linear(2048, self.out_channels),
999
- )
1000
- else:
1001
- raise NotImplementedError(f"Unexpected {pool} pooling")
1002
-
1003
- def convert_to_fp16(self):
1004
- """
1005
- Convert the torso of the model to float16.
1006
- """
1007
- self.input_blocks.apply(convert_module_to_f16)
1008
- self.middle_block.apply(convert_module_to_f16)
1009
-
1010
- def convert_to_fp32(self):
1011
- """
1012
- Convert the torso of the model to float32.
1013
- """
1014
- self.input_blocks.apply(convert_module_to_f32)
1015
- self.middle_block.apply(convert_module_to_f32)
1016
-
1017
- def forward(self, x, timesteps):
1018
- """
1019
- Apply the model to an input batch.
1020
- :param x: an [N x C x ...] Tensor of inputs.
1021
- :param timesteps: a 1-D batch of timesteps.
1022
- :return: an [N x K] Tensor of outputs.
1023
- """
1024
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1025
-
1026
- results = []
1027
- h = x.type(self.dtype)
1028
- for module in self.input_blocks:
1029
- h = module(h, emb)
1030
- if self.pool.startswith("spatial"):
1031
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1032
- h = self.middle_block(h, emb)
1033
- if self.pool.startswith("spatial"):
1034
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
1035
- h = th.cat(results, axis=-1)
1036
- return self.out(h)
1037
- else:
1038
- h = h.type(x.dtype)
1039
- return self.out(h)
1040
-

video_demo.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d4f71dce37b7e62ad467ec5d24004e8714be7e76bf634cd610c1935b03501ca6
3
- size 32058066