hyoungwoncho committed on
Commit 30c94b4
1 Parent(s): 52b8353

Update pipeline.py

Files changed (1)
  1. pipeline.py +107 -127
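
This commit renames the community pipeline class to `StableDiffusionPAGPipeline`, replaces the `do_adversarial_guidance` property with `do_perturbed_attention_guidance`, drops the unused `pag_drop_rate` and `pag_applied_layers` call arguments, hoists the attention-processor swap out of the denoising loop, and moves the `return_dict` early return after the processor restoration so the original attention processors are always put back. A minimal usage sketch of the updated pipeline follows; the base model id is an assumption, and the `custom_pipeline` repo id is inferred from the committer's namespace, not taken from this commit:

```python
import torch
from diffusers import DiffusionPipeline

# Load pipeline.py as a Hub custom pipeline (repo id assumed).
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
    torch_dtype=torch.float16,
).to("cuda")

# pag_scale > 0 enables perturbed-attention guidance (PAG); pag_applied_layers_index
# selects the self-attention layers to perturb ('d' = down, 'm' = mid, 'u' = up blocks).
image = pipe(
    prompt="a corgi wearing a top hat",
    num_inference_steps=50,
    guidance_scale=7.5,
    pag_scale=3.0,
    pag_applied_layers_index=["d4"],
).images[0]
image.save("pag_sample.png")
```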
pipeline.py CHANGED
@@ -12,8 +12,11 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import FusedAttnProcessor2_0
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0, FusedAttnProcessor2_0
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -24,11 +27,6 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -68,7 +66,7 @@ class PAGIdentitySelfAttnProcessor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -77,10 +75,10 @@ class PAGIdentitySelfAttnProcessor:
         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
+
         # chunk
         hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
-
+
         # original path
         batch_size, sequence_length, _ = hidden_states_org.shape

@@ -113,7 +111,7 @@ class PAGIdentitySelfAttnProcessor:

         hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states_org = hidden_states_org.to(query.dtype)
-
+
         # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
@@ -135,12 +133,12 @@ class PAGIdentitySelfAttnProcessor:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

         value = attn.to_v(hidden_states_ptb)
-
+
         hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
-        #hidden_states_ptb = value
-
+        # hidden_states_ptb = value
+
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
-
+
         # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
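
In the perturbed path above, this processor discards the self-attention result entirely and feeds zeros through the output projection; the commented-out line is the identity-attention variant, where each token attends only to itself and the attention output reduces to `value`. A standalone toy comparison of the three cases (illustrative only, not from the commit):

```python
import torch

q = k = v = torch.randn(1, 4, 8)  # (batch, seq_len, head_dim)

# ordinary self-attention output
attn_out = torch.softmax(q @ k.transpose(-1, -2) / 8**0.5, dim=-1) @ v
# identity attention: the attention map is the identity, so the output is just `value`
identity_out = v
# this commit's PAGIdentitySelfAttnProcessor zeroes the output instead
zero_out = torch.zeros_like(v)
```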
@@ -182,7 +180,7 @@ class PAGCFGIdentitySelfAttnProcessor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -191,11 +189,11 @@ class PAGCFGIdentitySelfAttnProcessor:
         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
+
         # chunk
         hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
         hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
-
+
         # original path
         batch_size, sequence_length, _ = hidden_states_org.shape

@@ -207,7 +205,7 @@ class PAGCFGIdentitySelfAttnProcessor:

         if attn.group_norm is not None:
             hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
-
+
         query = attn.to_q(hidden_states_org)
         key = attn.to_k(hidden_states_org)
         value = attn.to_v(hidden_states_org)
@@ -228,7 +226,7 @@ class PAGCFGIdentitySelfAttnProcessor:

         hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states_org = hidden_states_org.to(query.dtype)
-
+
         # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
@@ -252,7 +250,7 @@ class PAGCFGIdentitySelfAttnProcessor:
         value = attn.to_v(hidden_states_ptb)
         hidden_states_ptb = value
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
-
+
         # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
@@ -328,7 +326,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusionPipeline(
+class StableDiffusionPAGPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
@@ -976,7 +974,7 @@ class StableDiffusionPipeline(
         emb = torch.nn.functional.pad(emb, (0, 1))
         assert emb.shape == (w.shape[0], embedding_dim)
         return emb
-
+
     def pred_z0(self, sample, model_output, timestep):
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)

@@ -996,19 +994,14 @@ class StableDiffusionPipeline(
         )

         return pred_original_sample
-
+
     def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
-
         pred_z0 = self.pred_z0(latents, noise_pred, t)
-        pred_x0 = self.vae.decode(
-            pred_z0 / self.vae.config.scaling_factor,
-            return_dict=False,
-            generator=generator
-        )[0]
+        pred_x0 = self.vae.decode(pred_z0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
         pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
         do_denormalize = [True] * pred_x0.shape[0]
         pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
-
+
         return pred_x0

     @property
@@ -1041,36 +1034,27 @@ class StableDiffusionPipeline(
     @property
     def interrupt(self):
         return self._interrupt
-
+
     @property
     def pag_scale(self):
         return self._pag_scale
-
+
     @property
-    def do_adversarial_guidance(self):
+    def do_perturbed_attention_guidance(self):
         return self._pag_scale > 0
-
+
     @property
     def pag_adaptive_scaling(self):
         return self._pag_adaptive_scaling
-
+
     @property
     def do_pag_adaptive_scaling(self):
         return self._pag_adaptive_scaling > 0
-
-    @property
-    def pag_drop_rate(self):
-        return self._pag_drop_rate
-
-    @property
-    def pag_applied_layers(self):
-        return self._pag_applied_layers
-
+
     @property
     def pag_applied_layers_index(self):
         return self._pag_applied_layers_index

-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1083,9 +1067,7 @@ class StableDiffusionPipeline(
         guidance_scale: float = 7.5,
         pag_scale: float = 0.0,
         pag_adaptive_scaling: float = 0.0,
-        pag_drop_rate: float = 0.5,
-        pag_applied_layers: List[str] = ['down'], #['down', 'mid', 'up']
-        pag_applied_layers_index: List[str] = ['d4'], #['d4', 'd5', 'm0']
+        pag_applied_layers_index: List[str] = ["d4"],  # ['d4', 'd5', 'm0']
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
@@ -1221,11 +1203,9 @@ class StableDiffusionPipeline(
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False
-
+
         self._pag_scale = pag_scale
         self._pag_adaptive_scaling = pag_adaptive_scaling
-        self._pag_drop_rate = pag_drop_rate
-        self._pag_applied_layers = pag_applied_layers
         self._pag_applied_layers_index = pag_applied_layers_index

         # 2. Define call parameters
@@ -1258,15 +1238,15 @@ class StableDiffusionPipeline(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-
-        #cfg
-        if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+
+        # cfg
+        if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-        #pag
-        elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        # pag
+        elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
-        #both
-        elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        # both
+        elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
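
For clarity, a standalone sketch (not part of the commit) of the batch layout this concatenation sets up; the `chunk(2)`/`chunk(3)` calls in the attention processors and in the guidance branches below rely on this ordering along dim 0:

```python
import torch

# Dummy embeddings with the CLIP text-encoder shape used by SD 1.x: (batch, 77 tokens, 768 dims).
negative_prompt_embeds = torch.zeros(1, 77, 768)
prompt_embeds = torch.ones(1, 77, 768)

# CFG + PAG: [uncond, text, text]; the last copy is the one perturbed inside the UNet.
batch = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
uncond, text, text_perturb = batch.chunk(3)
assert torch.equal(text, text_perturb)  # identical going in; they diverge at the perturbed attn1 layers
```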
@@ -1309,21 +1289,44 @@ class StableDiffusionPipeline(
         ).to(device=device, dtype=latents.dtype)

         # 7. Denoising loop
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             down_layers = []
             mid_layers = []
             up_layers = []
             for name, module in self.unet.named_modules():
-                if 'attn1' in name and 'to' not in name:
-                    layer_type = name.split('.')[0].split('_')[0]
-                    if layer_type == 'down':
+                if "attn1" in name and "to" not in name:
+                    layer_type = name.split(".")[0].split("_")[0]
+                    if layer_type == "down":
                         down_layers.append(module)
-                    elif layer_type == 'mid':
+                    elif layer_type == "mid":
                         mid_layers.append(module)
-                    elif layer_type == 'up':
+                    elif layer_type == "up":
                         up_layers.append(module)
                     else:
                         raise ValueError(f"Invalid layer type: {layer_type}")
+
+        # change attention layer in UNet if use PAG
+        if self.do_perturbed_attention_guidance:
+            if self.do_classifier_free_guidance:
+                replace_processor = PAGCFGIdentitySelfAttnProcessor()
+            else:
+                replace_processor = PAGIdentitySelfAttnProcessor()
+
+            drop_layers = self.pag_applied_layers_index
+            for drop_layer in drop_layers:
+                try:
+                    if drop_layer[0] == "d":
+                        down_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "m":
+                        mid_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "u":
+                        up_layers[int(drop_layer[1])].processor = replace_processor
+                    else:
+                        raise ValueError(f"Invalid layer type: {drop_layer[0]}")
+                except IndexError:
+                    raise ValueError(
+                        f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
+                    )

         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
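
A standalone sketch (not from the commit) of the `pag_applied_layers_index` naming scheme the collection loop above enables: the first character selects the block list and the second character indexes into it, mirroring the `drop_layer[0]` / `int(drop_layer[1])` lookups:

```python
# Stand-ins for the self-attention modules collected from unet.named_modules() above.
down_layers = [f"down_block_attn1_{i}" for i in range(6)]
mid_layers = ["mid_block_attn1_0"]
up_layers = [f"up_block_attn1_{i}" for i in range(9)]

def resolve(code: str) -> str:
    """Map an index string like 'd4' to a layer: 'd'/'m'/'u' picks the list, the digit picks the slot."""
    buckets = {"d": down_layers, "m": mid_layers, "u": up_layers}
    return buckets[code[0]][int(code[1])]

print(resolve("d4"))  # down_block_attn1_4
print(resolve("m0"))  # mid_block_attn1_0
```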
@@ -1331,46 +1334,22 @@ class StableDiffusionPipeline(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-
-                #cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+
+                # cfg
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
-                #pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                # pag
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
-                #both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                # both
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 3)
-                #no
+                # no
                 else:
                     latent_model_input = latents
-
-                # change attention layer in UNet if use PAG
-                if self.do_adversarial_guidance:
-
-                    if self.do_classifier_free_guidance:
-                        replace_processor = PAGCFGIdentitySelfAttnProcessor()
-                    else:
-                        replace_processor = PAGIdentitySelfAttnProcessor()
-
-                    drop_layers = self.pag_applied_layers_index
-                    for drop_layer in drop_layers:
-                        try:
-                            if drop_layer[0] == 'd':
-                                down_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == 'm':
-                                mid_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == 'u':
-                                up_layers[int(drop_layer[1])].processor = replace_processor
-                            else:
-                                raise ValueError(f"Invalid layer type: {drop_layer[0]}")
-                        except IndexError:
-                            raise ValueError(
-                                f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
-                            )
-
+
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
@@ -1381,43 +1360,44 @@ class StableDiffusionPipeline(
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
-
+
                 # perform guidance
-
+
                 # cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
-
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-
+
                     delta = noise_pred_text - noise_pred_uncond
                     noise_pred = noise_pred_uncond + self.guidance_scale * delta
-
+
                 # pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
-
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
-
+
                     signal_scale = self.pag_scale
                     if self.do_pag_adaptive_scaling:
-                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
-                        if signal_scale<0:
+                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
+                        if signal_scale < 0:
                             signal_scale = 0
-
+
                     noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
-
+
                 # both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
-
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
-
+
                     signal_scale = self.pag_scale
                     if self.do_pag_adaptive_scaling:
-                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
-                        if signal_scale<0:
+                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
+                        if signal_scale < 0:
                             signal_scale = 0
-
-                    noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb)
-
+
+                    noise_pred = (
+                        noise_pred_text
+                        + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond)
+                        + signal_scale * (noise_pred_text - noise_pred_text_perturb)
+                    )
+
                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
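
A small numeric sketch (illustrative values only, not from the commit) of the combined CFG + PAG update in the `both` branch above, including the adaptive clamp on `signal_scale`:

```python
import torch

guidance_scale, pag_scale, pag_adaptive_scaling, t = 7.5, 3.0, 0.002, 800

# Adaptive scaling decays the PAG weight as t decreases toward the end of sampling, clamped at 0.
signal_scale = max(pag_scale - pag_adaptive_scaling * (1000 - t), 0.0)

uncond, text, perturb = torch.randn(3, 4, 64, 64).chunk(3)
noise_pred = (
    text
    + (guidance_scale - 1.0) * (text - uncond)  # classifier-free guidance term
    + signal_scale * (text - perturb)           # perturbed-attention guidance term
)
```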
@@ -1460,20 +1440,17 @@ class StableDiffusionPipeline(

         # Offload all models
         self.maybe_free_model_hooks()
-
-        if not return_dict:
-            return (image, has_nsfw_concept)
-
+
         # change attention layer in UNet if use PAG
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             drop_layers = self.pag_applied_layers_index
             for drop_layer in drop_layers:
                 try:
-                    if drop_layer[0] == 'd':
+                    if drop_layer[0] == "d":
                         down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
-                    elif drop_layer[0] == 'm':
+                    elif drop_layer[0] == "m":
                         mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
-                    elif drop_layer[0] == 'u':
+                    elif drop_layer[0] == "u":
                         up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
                     else:
                         raise ValueError(f"Invalid layer type: {drop_layer[0]}")
@@ -1482,4 +1459,7 @@ class StableDiffusionPipeline(
                         f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
                     )

+        if not return_dict:
+            return (image, has_nsfw_concept)
+
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 