hyoungwoncho committed • Commit 30c94b4 • Parent(s): 52b8353

Update pipeline.py

pipeline.py CHANGED (+107 -127)
@@ -12,8 +12,11 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import FusedAttnProcessor2_0
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0, FusedAttnProcessor2_0
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -24,11 +27,6 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -68,7 +66,7 @@ class PAGIdentitySelfAttnProcessor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-        
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -77,10 +75,10 @@ class PAGIdentitySelfAttnProcessor:
         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        
+
         # chunk
         hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
-        
+
         # original path
         batch_size, sequence_length, _ = hidden_states_org.shape
 
@@ -113,7 +111,7 @@ class PAGIdentitySelfAttnProcessor:
 
         hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states_org = hidden_states_org.to(query.dtype)
-        
+
         # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
@@ -135,12 +133,12 @@ class PAGIdentitySelfAttnProcessor:
             hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
 
         value = attn.to_v(hidden_states_ptb)
-        
+
         hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
-        #hidden_states_ptb = value
-        
+        # hidden_states_ptb = value
+
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
-        
+
         # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
@@ -182,7 +180,7 @@ class PAGCFGIdentitySelfAttnProcessor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-        
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -191,11 +189,11 @@ class PAGCFGIdentitySelfAttnProcessor:
         if input_ndim == 4:
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        
+
         # chunk
         hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
         hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
-        
+
         # original path
         batch_size, sequence_length, _ = hidden_states_org.shape
 
@@ -207,7 +205,7 @@ class PAGCFGIdentitySelfAttnProcessor:
 
         if attn.group_norm is not None:
             hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
-        
+
         query = attn.to_q(hidden_states_org)
         key = attn.to_k(hidden_states_org)
         value = attn.to_v(hidden_states_org)
@@ -228,7 +226,7 @@ class PAGCFGIdentitySelfAttnProcessor:
 
         hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states_org = hidden_states_org.to(query.dtype)
-        
+
         # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
         # dropout
@@ -252,7 +250,7 @@ class PAGCFGIdentitySelfAttnProcessor:
         value = attn.to_v(hidden_states_ptb)
         hidden_states_ptb = value
         hidden_states_ptb = hidden_states_ptb.to(query.dtype)
-        
+
         # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
         # dropout
@@ -328,7 +326,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusionPipeline(
+class StableDiffusionPAGPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
@@ -976,7 +974,7 @@ class StableDiffusionPipeline(
             emb = torch.nn.functional.pad(emb, (0, 1))
         assert emb.shape == (w.shape[0], embedding_dim)
         return emb
-    
+
     def pred_z0(self, sample, model_output, timestep):
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
 
@@ -996,19 +994,14 @@ class StableDiffusionPipeline(
         )
 
         return pred_original_sample
-    
+
     def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
-
         pred_z0 = self.pred_z0(latents, noise_pred, t)
-        pred_x0 = self.vae.decode(
-            pred_z0 / self.vae.config.scaling_factor,
-            return_dict=False,
-            generator=generator
-        )[0]
+        pred_x0 = self.vae.decode(pred_z0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
         pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
         do_denormalize = [True] * pred_x0.shape[0]
         pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
-        
+
         return pred_x0
 
     @property
@@ -1041,36 +1034,27 @@ class StableDiffusionPipeline(
     @property
     def interrupt(self):
         return self._interrupt
-    
+
     @property
     def pag_scale(self):
         return self._pag_scale
-    
+
     @property
-    def do_adversarial_guidance(self):
+    def do_perturbed_attention_guidance(self):
         return self._pag_scale > 0
-    
+
     @property
     def pag_adaptive_scaling(self):
         return self._pag_adaptive_scaling
-    
+
     @property
     def do_pag_adaptive_scaling(self):
         return self._pag_adaptive_scaling > 0
-    
-    @property
-    def pag_drop_rate(self):
-        return self._pag_drop_rate
-    
-    @property
-    def pag_applied_layers(self):
-        return self._pag_applied_layers
-    
+
     @property
     def pag_applied_layers_index(self):
         return self._pag_applied_layers_index
 
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1083,9 +1067,7 @@ class StableDiffusionPipeline(
         guidance_scale: float = 7.5,
         pag_scale: float = 0.0,
         pag_adaptive_scaling: float = 0.0,
-        pag_drop_rate: float = 0.5,
-        pag_applied_layers: List[str] = ['down'], #['down', 'mid', 'up']
-        pag_applied_layers_index: List[str] = ['d4'], #['d4', 'd5', 'm0']
+        pag_applied_layers_index: List[str] = ["d4"],  # ['d4', 'd5', 'm0']
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
@@ -1221,11 +1203,9 @@ class StableDiffusionPipeline(
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False
-        
+
         self._pag_scale = pag_scale
         self._pag_adaptive_scaling = pag_adaptive_scaling
-        self._pag_drop_rate = pag_drop_rate
-        self._pag_applied_layers = pag_applied_layers
         self._pag_applied_layers_index = pag_applied_layers_index
 
         # 2. Define call parameters
@@ -1258,15 +1238,15 @@ class StableDiffusionPipeline(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        
-        #cfg
-        if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+
+        # cfg
+        if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-        #pag
-        elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        # pag
+        elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
-        #both
-        elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        # both
+        elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -1309,21 +1289,44 @@ class StableDiffusionPipeline(
             ).to(device=device, dtype=latents.dtype)
 
         # 7. Denoising loop
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             down_layers = []
             mid_layers = []
             up_layers = []
             for name, module in self.unet.named_modules():
-                if 'attn1' in name and 'to' not in name:
-                    layer_type = name.split('.')[0].split('_')[0]
-                    if layer_type == 'down':
+                if "attn1" in name and "to" not in name:
+                    layer_type = name.split(".")[0].split("_")[0]
+                    if layer_type == "down":
                         down_layers.append(module)
-                    elif layer_type == 'mid':
+                    elif layer_type == "mid":
                         mid_layers.append(module)
-                    elif layer_type == 'up':
+                    elif layer_type == "up":
                         up_layers.append(module)
                     else:
                         raise ValueError(f"Invalid layer type: {layer_type}")
+
+        # change attention layer in UNet if use PAG
+        if self.do_perturbed_attention_guidance:
+            if self.do_classifier_free_guidance:
+                replace_processor = PAGCFGIdentitySelfAttnProcessor()
+            else:
+                replace_processor = PAGIdentitySelfAttnProcessor()
+
+            drop_layers = self.pag_applied_layers_index
+            for drop_layer in drop_layers:
+                try:
+                    if drop_layer[0] == "d":
+                        down_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "m":
+                        mid_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "u":
+                        up_layers[int(drop_layer[1])].processor = replace_processor
+                    else:
+                        raise ValueError(f"Invalid layer type: {drop_layer[0]}")
+                except IndexError:
+                    raise ValueError(
+                        f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
+                    )
 
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
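A note on the indexing this hunk introduces (as implemented above, now once before the loop instead of every step): each entry of `pag_applied_layers_index` is a letter plus a single digit. The letter selects one of the lists built from `self.unet.named_modules()` ('d' for down_layers, 'm' for mid_layers, 'u' for up_layers, collecting modules whose names contain "attn1" but not "to"), and the digit indexes into that list, so "d4" swaps the processor on down_layers[4]. Because the index is read as `int(drop_layer[1])`, a single character, layers beyond index 9 cannot be addressed with this scheme.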
@@ -1331,46 +1334,22 @@ class StableDiffusionPipeline(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-                
-                #cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+
+                # cfg
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
-                #pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                # pag
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
-                #both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                # both
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 3)
-                #no
+                # no
                 else:
                     latent_model_input = latents
-                
-                # change attention layer in UNet if use PAG
-                if self.do_adversarial_guidance:
-
-                    if self.do_classifier_free_guidance:
-                        replace_processor = PAGCFGIdentitySelfAttnProcessor()
-                    else:
-                        replace_processor = PAGIdentitySelfAttnProcessor()
-
-                    drop_layers = self.pag_applied_layers_index
-                    for drop_layer in drop_layers:
-                        try:
-                            if drop_layer[0] == 'd':
-                                down_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == 'm':
-                                mid_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == 'u':
-                                up_layers[int(drop_layer[1])].processor = replace_processor
-                            else:
-                                raise ValueError(f"Invalid layer type: {drop_layer[0]}")
-                        except IndexError:
-                            raise ValueError(
-                                f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
-                            )
-                
+
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
@@ -1381,43 +1360,44 @@ class StableDiffusionPipeline(
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
-                
+
                 # perform guidance
-                
+
                 # cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
-                
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                
+
                     delta = noise_pred_text - noise_pred_uncond
                     noise_pred = noise_pred_uncond + self.guidance_scale * delta
-                
+
                 # pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
-                
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
-                
+
                     signal_scale = self.pag_scale
                     if self.do_pag_adaptive_scaling:
-                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
-                        if signal_scale<0:
+                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
+                        if signal_scale < 0:
                             signal_scale = 0
-                
+
                     noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
-                
+
                 # both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
-                
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
-                
+
                     signal_scale = self.pag_scale
                     if self.do_pag_adaptive_scaling:
-                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
-                        if signal_scale<0:
+                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
+                        if signal_scale < 0:
                             signal_scale = 0
-                
-                    noise_pred = noise_pred_text + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb)
-                
+
+                    noise_pred = (
+                        noise_pred_text
+                        + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond)
+                        + signal_scale * (noise_pred_text - noise_pred_text_perturb)
+                    )
+
                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
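Written out, the "both" branch above is a single update around the conditional prediction: noise_pred = noise_pred_text + (guidance_scale - 1) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb). The first two terms are algebraically identical to standard CFG, noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond), and the PAG term adds an independent delta that pushes the prediction away from the perturbed identity self-attention path.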
@@ -1460,20 +1440,17 @@ class StableDiffusionPipeline(
 
         # Offload all models
         self.maybe_free_model_hooks()
-        
-        if not return_dict:
-            return (image, has_nsfw_concept)
-        
+
         # change attention layer in UNet if use PAG
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             drop_layers = self.pag_applied_layers_index
             for drop_layer in drop_layers:
                 try:
-                    if drop_layer[0] == 'd':
+                    if drop_layer[0] == "d":
                         down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
-                    elif drop_layer[0] == 'm':
+                    elif drop_layer[0] == "m":
                         mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
-                    elif drop_layer[0] == 'u':
+                    elif drop_layer[0] == "u":
                         up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
                     else:
                         raise ValueError(f"Invalid layer type: {drop_layer[0]}")
@@ -1482,4 +1459,7 @@ class StableDiffusionPipeline(
                         f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
                     )
 
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
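For reference, a minimal usage sketch of the pipeline after this commit. It is not part of the diff: the base checkpoint and custom-pipeline repo id below are placeholders, and the prompt and scale values are illustrative; only the PAG keyword arguments mirror the new __call__ signature above.

import torch
from diffusers import DiffusionPipeline

# Placeholder ids: substitute the actual base model and the Hub repo hosting this pipeline.py.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="<repo-with-this-pipeline.py>",
    torch_dtype=torch.float16,
).to("cuda")

output = pipe(
    "a photo of a corgi",              # illustrative prompt
    num_inference_steps=50,
    guidance_scale=0.0,                # CFG off: exercises the PAG-only branch
    pag_scale=5.0,                     # pag_scale > 0 turns do_perturbed_attention_guidance on
    pag_applied_layers_index=["d4"],   # 'd'/'m'/'u' plus an index into the collected self-attention layers
)
image = output.images[0]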