RamAnanth1 commited on
Commit
9ae63f1
1 Parent(s): 611f831

Update ldm/models/diffusion/ddim.py

Browse files
Files changed (1) hide show
  1. ldm/models/diffusion/ddim.py +5 -5
ldm/models/diffusion/ddim.py CHANGED
@@ -32,14 +32,14 @@ class DDIMSampler(object):
32
  self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
 
34
  # calculations for diffusion q(x_t | x_{t-1}) and others
35
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))
36
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))
37
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumpro))
38
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))
39
  self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
40
 
41
  # ddim sampling parameters
42
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
  ddim_timesteps=self.ddim_timesteps,
44
  eta=ddim_eta,verbose=verbose)
45
  self.register_buffer('ddim_sigmas', ddim_sigmas)
 
32
  self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
 
34
  # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
39
  self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
40
 
41
  # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod,
43
  ddim_timesteps=self.ddim_timesteps,
44
  eta=ddim_eta,verbose=verbose)
45
  self.register_buffer('ddim_sigmas', ddim_sigmas)