RamAnanth1 commited on
Commit
45fcf4c
1 Parent(s): 9ae63f1

Update ldm/models/diffusion/ddpm.py

Browse files
Files changed (1) hide show
  1. ldm/models/diffusion/ddpm.py +1 -1
ldm/models/diffusion/ddpm.py CHANGED
@@ -208,7 +208,7 @@ class DDPM(pl.LightningModule):
208
 
209
  @torch.no_grad()
210
  def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
211
- sd = torch.load(path, map_location="cpu")
212
  if "state_dict" in list(sd.keys()):
213
  sd = sd["state_dict"]
214
  keys = list(sd.keys())
 
208
 
209
  @torch.no_grad()
210
  def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
211
+ sd = torch.load(path, map_location="cuda")
212
  if "state_dict" in list(sd.keys()):
213
  sd = sd["state_dict"]
214
  keys = list(sd.keys())