Linoy Tsaban committed on
Commit
2bd2671
1 Parent(s): 248a53d

Update tokenflow_pnp.py

Browse files
Files changed (1) hide show
  1. tokenflow_pnp.py +90 -31
tokenflow_pnp.py CHANGED
@@ -9,6 +9,7 @@ import torchvision.transforms as T
9
  import argparse
10
  from PIL import Image
11
  import yaml
 
12
  from tqdm import tqdm
13
  from transformers import logging
14
  from diffusers import DDIMScheduler, StableDiffusionPipeline
@@ -25,9 +26,9 @@ VAE_BATCH_SIZE = 10
25
  class TokenFlow(nn.Module):
26
  def __init__(self, config,
27
  pipe,
28
- frames=None,
29
- # latents = None,
30
- inverted_latents = None):
31
  super().__init__()
32
  self.config = config
33
  self.device = config["device"]
@@ -61,7 +62,16 @@ class TokenFlow(nn.Module):
61
  print('SD model loaded')
62
 
63
  # data
64
- self.frames, self.inverted_latents = frames, inverted_latents
 
 
 
 
 
 
 
 
 
65
  self.latents_path = self.get_latents_path()
66
 
67
  # load frames
@@ -120,15 +130,13 @@ class TokenFlow(nn.Module):
120
 
121
  def get_latents_path(self):
122
  read_from_files = self.frames is None
123
- # read_from_files = True
124
  if read_from_files:
125
  latents_path = os.path.join(self.config["latents_path"], f'sd_{self.config["sd_version"]}',
126
  Path(self.config["data_path"]).stem, f'steps_{self.config["n_inversion_steps"]}')
127
  latents_path = [x for x in glob.glob(f'{latents_path}/*') if '.' not in Path(x).name]
128
  n_frames = [int([x for x in latents_path[i].split('/') if 'nframes' in x][0].split('_')[1]) for i in range(len(latents_path))]
129
- print("n_frames", n_frames)
130
  latents_path = latents_path[np.argmax(n_frames)]
131
- print("latents_path", latents_path)
132
  self.config["n_frames"] = min(max(n_frames), self.config["n_frames"])
133
 
134
  else:
@@ -138,9 +146,8 @@ class TokenFlow(nn.Module):
138
  if self.config["n_frames"] % self.config["batch_size"] != 0:
139
  # make n_frames divisible by batch_size
140
  self.config["n_frames"] = self.config["n_frames"] - (self.config["n_frames"] % self.config["batch_size"])
141
- print("Number of frames: ", self.config["n_frames"])
142
  if read_from_files:
143
- print("YOOOOOOO", os.path.join(latents_path, 'latents'))
144
  return os.path.join(latents_path, 'latents')
145
  else:
146
  return None
@@ -206,37 +213,61 @@ class TokenFlow(nn.Module):
206
  # encode to latents
207
  latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
208
  # get noise
209
- eps = self.get_ddim_eps(latents, range(self.config["n_frames"])).to(torch.float16).to(self.device)
 
 
 
 
 
 
210
  if not read_from_files:
211
  return None, frames, latents, eps
212
  return paths, frames, latents, eps
213
 
214
  def get_ddim_eps(self, latent, indices):
215
  read_from_files = self.inverted_latents is None
216
- # read_from_files = True
217
  if read_from_files:
218
  noisest = max([int(x.split('_')[-1].split('.')[0]) for x in glob.glob(os.path.join(self.latents_path, f'noisy_latents_*.pt'))])
219
- print("noisets:", noisest)
220
- print("indecies:", indices)
221
  latents_path = os.path.join(self.latents_path, f'noisy_latents_{noisest}.pt')
222
  noisy_latent = torch.load(latents_path)[indices].to(self.device)
223
-
224
- # path = os.path.join('test_latents', f'noisy_latents_{noisest}.pt')
225
- # f_noisy_latent = torch.load(path)[indices].to(self.device)
226
- # print(f_noisy_latent==noisy_latent)
227
  else:
228
  noisest = max([int(key.split("_")[-1]) for key in self.inverted_latents.keys()])
229
- print("noisets:", noisest)
230
- print("indecies:", indices)
231
  noisy_latent = self.inverted_latents[f'noisy_latents_{noisest}'][indices]
232
 
233
  alpha_prod_T = self.scheduler.alphas_cumprod[noisest]
234
  mu_T, sigma_T = alpha_prod_T ** 0.5, (1 - alpha_prod_T) ** 0.5
235
  eps = (noisy_latent - mu_T * latent) / sigma_T
236
  return eps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  @torch.no_grad()
239
- def denoise_step(self, x, t, indices):
240
  # register the time step and features in pnp injection modules
241
  read_files = self.inverted_latents is None
242
 
@@ -264,21 +295,31 @@ class TokenFlow(nn.Module):
264
  noise_pred = noise_pred_uncond + self.config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
265
 
266
  # compute the denoising step with the reference model
267
- denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
 
268
  return denoised_latent
269
 
270
  @torch.autocast(dtype=torch.float16, device_type='cuda')
271
- def batched_denoise_step(self, x, t, indices):
272
  batch_size = self.config["batch_size"]
273
  denoised_latents = []
274
- pivotal_idx = torch.randint(batch_size, (len(x)//batch_size,)) + torch.arange(0,len(x),batch_size)
275
-
276
  register_pivotal(self, True)
277
- self.denoise_step(x[pivotal_idx], t, indices[pivotal_idx])
 
 
 
 
278
  register_pivotal(self, False)
279
  for i, b in enumerate(range(0, len(x), batch_size)):
280
  register_batch_idx(self, i)
281
- denoised_latents.append(self.denoise_step(x[b:b + batch_size], t, indices[b:b + batch_size]))
 
 
 
 
 
282
  denoised_latents = torch.cat(denoised_latents)
283
  return denoised_latents
284
 
@@ -309,7 +350,13 @@ class TokenFlow(nn.Module):
309
 
310
  self.init_method(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)
311
 
312
- noisy_latents = self.scheduler.add_noise(self.latents, self.eps, self.scheduler.timesteps[0])
 
 
 
 
 
 
313
  edited_frames = self.sample_loop(noisy_latents, torch.arange(self.config["n_frames"]))
314
 
315
  if save_files:
@@ -321,12 +368,24 @@ class TokenFlow(nn.Module):
321
  return edited_frames
322
 
323
  def sample_loop(self, x, indices):
324
- save_files = self.inverted_latents is None # if we're in the original non-demo setting
325
- # save_files = True
326
  if save_files:
327
  os.makedirs(f'{self.config["output_path"]}/img_ode', exist_ok=True)
328
- for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Sampling")):
329
- x = self.batched_denoise_step(x, t, indices)
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  decoded_latents = self.decode_latents(x)
332
  if save_files:
 
9
  import argparse
10
  from PIL import Image
11
  import yaml
12
+ import inspect
13
  from tqdm import tqdm
14
  from transformers import logging
15
  from diffusers import DDIMScheduler, StableDiffusionPipeline
 
26
  class TokenFlow(nn.Module):
27
  def __init__(self, config,
28
  pipe,
29
+ frames = None,
30
+ inverted_latents = None, #X0,...,XT,
31
+ zs = None):
32
  super().__init__()
33
  self.config = config
34
  self.device = config["device"]
 
62
  print('SD model loaded')
63
 
64
  # data
65
+ self.inversion = config['inversion']
66
+ if self.inversion == 'ddpm':
67
+ self.skip_steps = config['skip_steps']
68
+ self.eta = 1.0
69
+ else:
70
+ self.eta = 0.0
71
+ self.extra_step_kwargs = self.prepare_extra_step_kwargs(self.eta)
72
+
73
+ # data
74
+ self.frames, self.inverted_latents, self.zs = frames, inverted_latents, zs
75
  self.latents_path = self.get_latents_path()
76
 
77
  # load frames
 
130
 
131
  def get_latents_path(self):
132
  read_from_files = self.frames is None
 
133
  if read_from_files:
134
  latents_path = os.path.join(self.config["latents_path"], f'sd_{self.config["sd_version"]}',
135
  Path(self.config["data_path"]).stem, f'steps_{self.config["n_inversion_steps"]}')
136
  latents_path = [x for x in glob.glob(f'{latents_path}/*') if '.' not in Path(x).name]
137
  n_frames = [int([x for x in latents_path[i].split('/') if 'nframes' in x][0].split('_')[1]) for i in range(len(latents_path))]
 
138
  latents_path = latents_path[np.argmax(n_frames)]
139
+
140
  self.config["n_frames"] = min(max(n_frames), self.config["n_frames"])
141
 
142
  else:
 
146
  if self.config["n_frames"] % self.config["batch_size"] != 0:
147
  # make n_frames divisible by batch_size
148
  self.config["n_frames"] = self.config["n_frames"] - (self.config["n_frames"] % self.config["batch_size"])
149
+
150
  if read_from_files:
 
151
  return os.path.join(latents_path, 'latents')
152
  else:
153
  return None
 
213
  # encode to latents
214
  latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
215
  # get noise
216
+ if self.inversion == 'ddim':
217
+ eps = self.get_ddim_eps(latents, range(self.config["n_frames"])).to(torch.float16).to(self.device)
218
+ elif self.inversion == 'ddpm':
219
+ eps = self.get_ddpm_noise()
220
+ else:
221
+ raise NotImplementedError()
222
+
223
  if not read_from_files:
224
  return None, frames, latents, eps
225
  return paths, frames, latents, eps
226
 
227
  def get_ddim_eps(self, latent, indices):
228
  read_from_files = self.inverted_latents is None
 
229
  if read_from_files:
230
  noisest = max([int(x.split('_')[-1].split('.')[0]) for x in glob.glob(os.path.join(self.latents_path, f'noisy_latents_*.pt'))])
 
 
231
  latents_path = os.path.join(self.latents_path, f'noisy_latents_{noisest}.pt')
232
  noisy_latent = torch.load(latents_path)[indices].to(self.device)
 
 
 
 
233
  else:
234
  noisest = max([int(key.split("_")[-1]) for key in self.inverted_latents.keys()])
 
 
235
  noisy_latent = self.inverted_latents[f'noisy_latents_{noisest}'][indices]
236
 
237
  alpha_prod_T = self.scheduler.alphas_cumprod[noisest]
238
  mu_T, sigma_T = alpha_prod_T ** 0.5, (1 - alpha_prod_T) ** 0.5
239
  eps = (noisy_latent - mu_T * latent) / sigma_T
240
  return eps
241
+
242
def get_ddpm_noise(self):
    """Return the DDPM-inversion starting latents and the per-step noise maps.

    The starting timestep is ``self.scheduler.timesteps[self.skip_steps]``.
    Data is read from ``self.latents_path`` (``noisy_latents_<t>.pt`` and
    ``noise_total.pt``) when ``self.inverted_latents`` is ``None``, and from
    ``self.inverted_latents`` / ``self.zs`` otherwise.

    Returns:
        A tuple ``(x_t, zs)``: ``x_t`` is the stored latent for the starting
        timestep, truncated to the first ``config["n_frames"]`` entries;
        ``zs`` is the stored noise with the first ``skip_steps`` entries of
        axis 0 dropped and axis 1 truncated to ``config["n_frames"]``.
        Both are moved to ``self.device``.

    Raises:
        IndexError: if ``skip_steps`` is not a valid non-negative index into
            the scheduler timesteps.
        ValueError: if in-memory latents were supplied without ``zs``.
    """
    timesteps = self.scheduler.timesteps
    skip = self.skip_steps
    # Negative values are rejected explicitly so they cannot silently wrap
    # around to the end of the schedule.
    if not 0 <= skip < len(timesteps):
        raise IndexError(
            f'skip_steps={skip} is outside the scheduler range [0, {len(timesteps)})')
    t = int(timesteps[skip])
    n_frames = self.config["n_frames"]

    if self.inverted_latents is None:
        x_t = torch.load(os.path.join(self.latents_path, f'noisy_latents_{t}.pt'))
        zs = torch.load(os.path.join(self.latents_path, 'noise_total.pt'))
    else:
        if self.zs is None:
            raise ValueError('DDPM inversion requires `zs` when inverted_latents is given')
        x_t = self.inverted_latents[f'noisy_latents_{t}']
        zs = self.zs

    x_t = x_t[:n_frames].to(self.device)
    zs = zs[skip:, :n_frames].to(self.device)
    return x_t, zs
255
+
256
def prepare_extra_step_kwargs(self, eta):
    """Build the optional keyword arguments forwarded to ``self.scheduler.step``.

    Not all schedulers share the same ``step`` signature, so ``eta`` is only
    included when ``self.scheduler.step`` declares a parameter named ``eta``;
    otherwise it is dropped. ``eta`` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between 0 and 1.

    Args:
        eta: value to forward as the ``eta`` keyword.

    Returns:
        ``{"eta": eta}`` if the scheduler step accepts ``eta``, else ``{}``.
    """
    # `parameters` is already a mapping keyed by name; no need to copy the
    # keys into a set just to test membership.
    step_params = inspect.signature(self.scheduler.step).parameters
    return {"eta": eta} if "eta" in step_params else {}
268
 
269
  @torch.no_grad()
270
+ def denoise_step(self, x, t, indices, zs=None):
271
  # register the time step and features in pnp injection modules
272
  read_files = self.inverted_latents is None
273
 
 
295
  noise_pred = noise_pred_uncond + self.config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
296
 
297
  # compute the denoising step with the reference model
298
+ denoised_latent = self.scheduler.step(noise_pred, t, x, variance_noise=zs, **self.extra_step_kwargs)[
299
+ 'prev_sample']
300
  return denoised_latent
301
 
302
  @torch.autocast(dtype=torch.float16, device_type='cuda')
303
+ def batched_denoise_step(self, x, t, indices, zs=None):
304
  batch_size = self.config["batch_size"]
305
  denoised_latents = []
306
+ pivotal_idx = torch.randint(batch_size, (len(x) // batch_size,)) + torch.arange(0, len(x), batch_size)
307
+
308
  register_pivotal(self, True)
309
+ if zs is None:
310
+ zs_input = None
311
+ else:
312
+ zs_input = zs[pivotal_idx]
313
+ self.denoise_step(x[pivotal_idx], t, indices[pivotal_idx], zs_input)
314
  register_pivotal(self, False)
315
  for i, b in enumerate(range(0, len(x), batch_size)):
316
  register_batch_idx(self, i)
317
+ if zs is None:
318
+ zs_input = None
319
+ else:
320
+ zs_input = zs[b:b + batch_size]
321
+ denoised_latents.append(self.denoise_step(x[b:b + batch_size], t, indices[b:b + batch_size]
322
+ , zs_input))
323
  denoised_latents = torch.cat(denoised_latents)
324
  return denoised_latents
325
 
 
350
 
351
  self.init_method(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)
352
 
353
+ if self.inversion == 'ddim':
354
+ noisy_latents = self.scheduler.add_noise(self.latents, self.eps, self.scheduler.timesteps[0])
355
+ elif self.inversion == 'ddpm':
356
+ noisy_latents = self.eps[0]
357
+ else:
358
+ raise NotImplementedError()
359
+
360
  edited_frames = self.sample_loop(noisy_latents, torch.arange(self.config["n_frames"]))
361
 
362
  if save_files:
 
368
  return edited_frames
369
 
370
  def sample_loop(self, x, indices):
371
+ save_files = self.inverted_latents is None # if we're in the original non-demo setting
 
372
  if save_files:
373
  os.makedirs(f'{self.config["output_path"]}/img_ode', exist_ok=True)
374
+
375
+ timesteps = self.scheduler.timesteps
376
+ if self.inversion == 'ddpm':
377
+ zs_total = self.eps[1]
378
+
379
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs_total.shape[0]:])}
380
+ timesteps = timesteps[-zs_total.shape[0]:]
381
+
382
+ for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
383
+ if self.inversion == 'ddpm':
384
+ idx = t_to_idx[int(t)]
385
+ zs = zs_total[idx]
386
+ else:
387
+ zs = None
388
+ x = self.batched_denoise_step(x, t, indices, zs)
389
 
390
  decoded_latents = self.decode_latents(x)
391
  if save_files: