Spaces:
Build error
Build error
Warvito
commited on
Commit
•
1d1c1e7
1
Parent(s):
6686e5d
Try fix cpu only
Browse files- models/ddim.py +3 -3
models/ddim.py
CHANGED
@@ -68,8 +68,8 @@ class DDIMSampler(object):
|
|
68 |
|
69 |
def register_buffer(self, name, attr):
|
70 |
if type(attr) == torch.Tensor:
|
71 |
-
if attr.device != torch.device("
|
72 |
-
attr = attr.to(torch.device("
|
73 |
setattr(self, name, attr)
|
74 |
|
75 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
@@ -77,7 +77,7 @@ class DDIMSampler(object):
|
|
77 |
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
78 |
alphas_cumprod = self.model.alphas_cumprod
|
79 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
80 |
-
to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.device("
|
81 |
|
82 |
self.register_buffer('betas', to_torch(self.model.betas))
|
83 |
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
|
|
68 |
|
69 |
def register_buffer(self, name, attr):
|
70 |
if type(attr) == torch.Tensor:
|
71 |
+
if attr.device != torch.device("cpu"):
|
72 |
+
attr = attr.to(torch.device("cpu"))
|
73 |
setattr(self, name, attr)
|
74 |
|
75 |
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
|
|
77 |
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
78 |
alphas_cumprod = self.model.alphas_cumprod
|
79 |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
80 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.device("cpu"))
|
81 |
|
82 |
self.register_buffer('betas', to_torch(self.model.betas))
|
83 |
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|