AI-Anchorite commited on
Commit
4e18f8a
1 Parent(s): abe2421

Delete allegro/models/transformers/block.py

Browse files
Files changed (1) hide show
  1. allegro/models/transformers/block.py +0 -1195
allegro/models/transformers/block.py DELETED
@@ -1,1195 +0,0 @@
1
- # Adapted from Open-Sora-Plan
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
- # --------------------------------------------------------
9
-
10
-
11
- from importlib import import_module
12
- from typing import Any, Callable, Dict, Optional, Tuple
13
-
14
- import numpy as np
15
- import torch
16
- import collections
17
- import torch.nn.functional as F
18
- from torch.nn.attention import SDPBackend, sdpa_kernel
19
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
20
- from diffusers.models.attention_processor import (
21
- AttnAddedKVProcessor,
22
- AttnAddedKVProcessor2_0,
23
- AttnProcessor,
24
- CustomDiffusionAttnProcessor,
25
- CustomDiffusionAttnProcessor2_0,
26
- CustomDiffusionXFormersAttnProcessor,
27
- LoRAAttnAddedKVProcessor,
28
- LoRAAttnProcessor,
29
- LoRAAttnProcessor2_0,
30
- LoRAXFormersAttnProcessor,
31
- SlicedAttnAddedKVProcessor,
32
- SlicedAttnProcessor,
33
- SpatialNorm,
34
- XFormersAttnAddedKVProcessor,
35
- XFormersAttnProcessor,
36
- )
37
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
38
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
39
- from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_xformers_available
40
- from diffusers.utils.torch_utils import maybe_allow_in_graph
41
- from torch import nn
42
-
43
- from allegro.models.transformers.rope import RoPE3D, PositionGetter3D
44
- from allegro.models.transformers.embedding import CombinedTimestepSizeEmbeddings
45
-
46
- if is_xformers_available():
47
- import xformers
48
- import xformers.ops
49
- else:
50
- xformers = None
51
-
52
- from diffusers.utils import logging
53
-
54
- logger = logging.get_logger(__name__)
55
-
56
-
57
- def to_2tuple(x):
58
- if isinstance(x, collections.abc.Iterable):
59
- return x
60
- return (x, x)
61
-
62
-
63
- @maybe_allow_in_graph
64
- class Attention(nn.Module):
65
- r"""
66
- A cross attention layer.
67
-
68
- Parameters:
69
- query_dim (`int`):
70
- The number of channels in the query.
71
- cross_attention_dim (`int`, *optional*):
72
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
73
- heads (`int`, *optional*, defaults to 8):
74
- The number of heads to use for multi-head attention.
75
- dim_head (`int`, *optional*, defaults to 64):
76
- The number of channels in each head.
77
- dropout (`float`, *optional*, defaults to 0.0):
78
- The dropout probability to use.
79
- bias (`bool`, *optional*, defaults to False):
80
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
81
- upcast_attention (`bool`, *optional*, defaults to False):
82
- Set to `True` to upcast the attention computation to `float32`.
83
- upcast_softmax (`bool`, *optional*, defaults to False):
84
- Set to `True` to upcast the softmax computation to `float32`.
85
- cross_attention_norm (`str`, *optional*, defaults to `None`):
86
- The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
87
- cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
88
- The number of groups to use for the group norm in the cross attention.
89
- added_kv_proj_dim (`int`, *optional*, defaults to `None`):
90
- The number of channels to use for the added key and value projections. If `None`, no projection is used.
91
- norm_num_groups (`int`, *optional*, defaults to `None`):
92
- The number of groups to use for the group norm in the attention.
93
- spatial_norm_dim (`int`, *optional*, defaults to `None`):
94
- The number of channels to use for the spatial normalization.
95
- out_bias (`bool`, *optional*, defaults to `True`):
96
- Set to `True` to use a bias in the output linear layer.
97
- scale_qk (`bool`, *optional*, defaults to `True`):
98
- Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
99
- only_cross_attention (`bool`, *optional*, defaults to `False`):
100
- Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
101
- `added_kv_proj_dim` is not `None`.
102
- eps (`float`, *optional*, defaults to 1e-5):
103
- An additional value added to the denominator in group normalization that is used for numerical stability.
104
- rescale_output_factor (`float`, *optional*, defaults to 1.0):
105
- A factor to rescale the output by dividing it with this value.
106
- residual_connection (`bool`, *optional*, defaults to `False`):
107
- Set to `True` to add the residual connection to the output.
108
- _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
109
- Set to `True` if the attention block is loaded from a deprecated state dict.
110
- processor (`AttnProcessor`, *optional*, defaults to `None`):
111
- The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
112
- `AttnProcessor` otherwise.
113
- """
114
-
115
- def __init__(
116
- self,
117
- query_dim: int,
118
- cross_attention_dim: Optional[int] = None,
119
- heads: int = 8,
120
- dim_head: int = 64,
121
- dropout: float = 0.0,
122
- bias: bool = False,
123
- upcast_attention: bool = False,
124
- upcast_softmax: bool = False,
125
- cross_attention_norm: Optional[str] = None,
126
- cross_attention_norm_num_groups: int = 32,
127
- added_kv_proj_dim: Optional[int] = None,
128
- norm_num_groups: Optional[int] = None,
129
- spatial_norm_dim: Optional[int] = None,
130
- out_bias: bool = True,
131
- scale_qk: bool = True,
132
- only_cross_attention: bool = False,
133
- eps: float = 1e-5,
134
- rescale_output_factor: float = 1.0,
135
- residual_connection: bool = False,
136
- _from_deprecated_attn_block: bool = False,
137
- processor: Optional["AttnProcessor"] = None,
138
- attention_mode: str = "xformers",
139
- use_rope: bool = False,
140
- interpolation_scale_thw=None,
141
- ):
142
- super().__init__()
143
- self.inner_dim = dim_head * heads
144
- self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
145
- self.upcast_attention = upcast_attention
146
- self.upcast_softmax = upcast_softmax
147
- self.rescale_output_factor = rescale_output_factor
148
- self.residual_connection = residual_connection
149
- self.dropout = dropout
150
- self.use_rope = use_rope
151
-
152
- # we make use of this private variable to know whether this class is loaded
153
- # with an deprecated state dict so that we can convert it on the fly
154
- self._from_deprecated_attn_block = _from_deprecated_attn_block
155
-
156
- self.scale_qk = scale_qk
157
- self.scale = dim_head**-0.5 if self.scale_qk else 1.0
158
-
159
- self.heads = heads
160
- # for slice_size > 0 the attention score computation
161
- # is split across the batch axis to save memory
162
- # You can set slice_size with `set_attention_slice`
163
- self.sliceable_head_dim = heads
164
-
165
- self.added_kv_proj_dim = added_kv_proj_dim
166
- self.only_cross_attention = only_cross_attention
167
-
168
- if self.added_kv_proj_dim is None and self.only_cross_attention:
169
- raise ValueError(
170
- "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
171
- )
172
-
173
- if norm_num_groups is not None:
174
- self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
175
- else:
176
- self.group_norm = None
177
-
178
- if spatial_norm_dim is not None:
179
- self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
180
- else:
181
- self.spatial_norm = None
182
-
183
- if cross_attention_norm is None:
184
- self.norm_cross = None
185
- elif cross_attention_norm == "layer_norm":
186
- self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
187
- elif cross_attention_norm == "group_norm":
188
- if self.added_kv_proj_dim is not None:
189
- # The given `encoder_hidden_states` are initially of shape
190
- # (batch_size, seq_len, added_kv_proj_dim) before being projected
191
- # to (batch_size, seq_len, cross_attention_dim). The norm is applied
192
- # before the projection, so we need to use `added_kv_proj_dim` as
193
- # the number of channels for the group norm.
194
- norm_cross_num_channels = added_kv_proj_dim
195
- else:
196
- norm_cross_num_channels = self.cross_attention_dim
197
-
198
- self.norm_cross = nn.GroupNorm(
199
- num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
200
- )
201
- else:
202
- raise ValueError(
203
- f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
204
- )
205
-
206
- linear_cls = nn.Linear
207
-
208
-
209
- self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
210
-
211
- if not self.only_cross_attention:
212
- # only relevant for the `AddedKVProcessor` classes
213
- self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
214
- self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
215
- else:
216
- self.to_k = None
217
- self.to_v = None
218
-
219
- if self.added_kv_proj_dim is not None:
220
- self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
221
- self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
222
-
223
- self.to_out = nn.ModuleList([])
224
- self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
225
- self.to_out.append(nn.Dropout(dropout))
226
-
227
- # set attention processor
228
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
229
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
230
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
231
- if processor is None:
232
- processor = (
233
- AttnProcessor2_0(
234
- attention_mode,
235
- use_rope,
236
- interpolation_scale_thw=interpolation_scale_thw,
237
- )
238
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
239
- else AttnProcessor()
240
- )
241
- self.set_processor(processor)
242
-
243
- def set_use_memory_efficient_attention_xformers(
244
- self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
245
- ) -> None:
246
- r"""
247
- Set whether to use memory efficient attention from `xformers` or not.
248
-
249
- Args:
250
- use_memory_efficient_attention_xformers (`bool`):
251
- Whether to use memory efficient attention from `xformers` or not.
252
- attention_op (`Callable`, *optional*):
253
- The attention operation to use. Defaults to `None` which uses the default attention operation from
254
- `xformers`.
255
- """
256
- is_lora = hasattr(self, "processor")
257
- is_custom_diffusion = hasattr(self, "processor") and isinstance(
258
- self.processor,
259
- (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
260
- )
261
- is_added_kv_processor = hasattr(self, "processor") and isinstance(
262
- self.processor,
263
- (
264
- AttnAddedKVProcessor,
265
- AttnAddedKVProcessor2_0,
266
- SlicedAttnAddedKVProcessor,
267
- XFormersAttnAddedKVProcessor,
268
- LoRAAttnAddedKVProcessor,
269
- ),
270
- )
271
-
272
- if use_memory_efficient_attention_xformers:
273
- if is_added_kv_processor and (is_lora or is_custom_diffusion):
274
- raise NotImplementedError(
275
- f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
276
- )
277
- if not is_xformers_available():
278
- raise ModuleNotFoundError(
279
- (
280
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
281
- " xformers"
282
- ),
283
- name="xformers",
284
- )
285
- elif not torch.cuda.is_available():
286
- raise ValueError(
287
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
288
- " only available for GPU "
289
- )
290
- else:
291
- try:
292
- # Make sure we can run the memory efficient attention
293
- _ = xformers.ops.memory_efficient_attention(
294
- torch.randn((1, 2, 40), device="cuda"),
295
- torch.randn((1, 2, 40), device="cuda"),
296
- torch.randn((1, 2, 40), device="cuda"),
297
- )
298
- except Exception as e:
299
- raise e
300
-
301
- if is_lora:
302
- # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
303
- # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
304
- processor = LoRAXFormersAttnProcessor(
305
- hidden_size=self.processor.hidden_size,
306
- cross_attention_dim=self.processor.cross_attention_dim,
307
- rank=self.processor.rank,
308
- attention_op=attention_op,
309
- )
310
- processor.load_state_dict(self.processor.state_dict())
311
- processor.to(self.processor.to_q_lora.up.weight.device)
312
- elif is_custom_diffusion:
313
- processor = CustomDiffusionXFormersAttnProcessor(
314
- train_kv=self.processor.train_kv,
315
- train_q_out=self.processor.train_q_out,
316
- hidden_size=self.processor.hidden_size,
317
- cross_attention_dim=self.processor.cross_attention_dim,
318
- attention_op=attention_op,
319
- )
320
- processor.load_state_dict(self.processor.state_dict())
321
- if hasattr(self.processor, "to_k_custom_diffusion"):
322
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
323
- elif is_added_kv_processor:
324
- # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
325
- # which uses this type of cross attention ONLY because the attention mask of format
326
- # [0, ..., -10.000, ..., 0, ...,] is not supported
327
- # throw warning
328
- logger.info(
329
- "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
330
- )
331
- processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
332
- else:
333
- processor = XFormersAttnProcessor(attention_op=attention_op)
334
- else:
335
- if is_lora:
336
- attn_processor_class = (
337
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
338
- )
339
- processor = attn_processor_class(
340
- hidden_size=self.processor.hidden_size,
341
- cross_attention_dim=self.processor.cross_attention_dim,
342
- rank=self.processor.rank,
343
- )
344
- processor.load_state_dict(self.processor.state_dict())
345
- processor.to(self.processor.to_q_lora.up.weight.device)
346
- elif is_custom_diffusion:
347
- attn_processor_class = (
348
- CustomDiffusionAttnProcessor2_0
349
- if hasattr(F, "scaled_dot_product_attention")
350
- else CustomDiffusionAttnProcessor
351
- )
352
- processor = attn_processor_class(
353
- train_kv=self.processor.train_kv,
354
- train_q_out=self.processor.train_q_out,
355
- hidden_size=self.processor.hidden_size,
356
- cross_attention_dim=self.processor.cross_attention_dim,
357
- )
358
- processor.load_state_dict(self.processor.state_dict())
359
- if hasattr(self.processor, "to_k_custom_diffusion"):
360
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
361
- else:
362
- # set attention processor
363
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
364
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
365
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
366
- processor = (
367
- AttnProcessor2_0()
368
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
369
- else AttnProcessor()
370
- )
371
-
372
- self.set_processor(processor)
373
-
374
- def set_attention_slice(self, slice_size: int) -> None:
375
- r"""
376
- Set the slice size for attention computation.
377
-
378
- Args:
379
- slice_size (`int`):
380
- The slice size for attention computation.
381
- """
382
- if slice_size is not None and slice_size > self.sliceable_head_dim:
383
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
384
-
385
- if slice_size is not None and self.added_kv_proj_dim is not None:
386
- processor = SlicedAttnAddedKVProcessor(slice_size)
387
- elif slice_size is not None:
388
- processor = SlicedAttnProcessor(slice_size)
389
- elif self.added_kv_proj_dim is not None:
390
- processor = AttnAddedKVProcessor()
391
- else:
392
- # set attention processor
393
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
394
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
395
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
396
- processor = (
397
- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
398
- )
399
-
400
- self.set_processor(processor)
401
-
402
- def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
403
- r"""
404
- Set the attention processor to use.
405
-
406
- Args:
407
- processor (`AttnProcessor`):
408
- The attention processor to use.
409
- _remove_lora (`bool`, *optional*, defaults to `False`):
410
- Set to `True` to remove LoRA layers from the model.
411
- """
412
- if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
413
- deprecate(
414
- "set_processor to offload LoRA",
415
- "0.26.0",
416
- "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
417
- )
418
- # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
419
- # We need to remove all LoRA layers
420
- # Don't forget to remove ALL `_remove_lora` from the codebase
421
- for module in self.modules():
422
- if hasattr(module, "set_lora_layer"):
423
- module.set_lora_layer(None)
424
-
425
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
426
- # pop `processor` from `self._modules`
427
- if (
428
- hasattr(self, "processor")
429
- and isinstance(self.processor, torch.nn.Module)
430
- and not isinstance(processor, torch.nn.Module)
431
- ):
432
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
433
- self._modules.pop("processor")
434
-
435
- self.processor = processor
436
-
437
- def get_processor(self, return_deprecated_lora: bool = False):
438
- r"""
439
- Get the attention processor in use.
440
-
441
- Args:
442
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
443
- Set to `True` to return the deprecated LoRA attention processor.
444
-
445
- Returns:
446
- "AttentionProcessor": The attention processor in use.
447
- """
448
- if not return_deprecated_lora:
449
- return self.processor
450
-
451
- # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
452
- # serialization format for LoRA Attention Processors. It should be deleted once the integration
453
- # with PEFT is completed.
454
- is_lora_activated = {
455
- name: module.lora_layer is not None
456
- for name, module in self.named_modules()
457
- if hasattr(module, "lora_layer")
458
- }
459
-
460
- # 1. if no layer has a LoRA activated we can return the processor as usual
461
- if not any(is_lora_activated.values()):
462
- return self.processor
463
-
464
- # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
465
- is_lora_activated.pop("add_k_proj", None)
466
- is_lora_activated.pop("add_v_proj", None)
467
- # 2. else it is not posssible that only some layers have LoRA activated
468
- if not all(is_lora_activated.values()):
469
- raise ValueError(
470
- f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
471
- )
472
-
473
- # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
474
- non_lora_processor_cls_name = self.processor.__class__.__name__
475
- lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
476
-
477
- hidden_size = self.inner_dim
478
-
479
- # now create a LoRA attention processor from the LoRA layers
480
- if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
481
- kwargs = {
482
- "cross_attention_dim": self.cross_attention_dim,
483
- "rank": self.to_q.lora_layer.rank,
484
- "network_alpha": self.to_q.lora_layer.network_alpha,
485
- "q_rank": self.to_q.lora_layer.rank,
486
- "q_hidden_size": self.to_q.lora_layer.out_features,
487
- "k_rank": self.to_k.lora_layer.rank,
488
- "k_hidden_size": self.to_k.lora_layer.out_features,
489
- "v_rank": self.to_v.lora_layer.rank,
490
- "v_hidden_size": self.to_v.lora_layer.out_features,
491
- "out_rank": self.to_out[0].lora_layer.rank,
492
- "out_hidden_size": self.to_out[0].lora_layer.out_features,
493
- }
494
-
495
- if hasattr(self.processor, "attention_op"):
496
- kwargs["attention_op"] = self.processor.attention_op
497
-
498
- lora_processor = lora_processor_cls(hidden_size, **kwargs)
499
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
500
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
501
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
502
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
503
- elif lora_processor_cls == LoRAAttnAddedKVProcessor:
504
- lora_processor = lora_processor_cls(
505
- hidden_size,
506
- cross_attention_dim=self.add_k_proj.weight.shape[0],
507
- rank=self.to_q.lora_layer.rank,
508
- network_alpha=self.to_q.lora_layer.network_alpha,
509
- )
510
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
511
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
512
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
513
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
514
-
515
- # only save if used
516
- if self.add_k_proj.lora_layer is not None:
517
- lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
518
- lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
519
- else:
520
- lora_processor.add_k_proj_lora = None
521
- lora_processor.add_v_proj_lora = None
522
- else:
523
- raise ValueError(f"{lora_processor_cls} does not exist.")
524
-
525
- return lora_processor
526
-
527
- def forward(
528
- self,
529
- hidden_states: torch.FloatTensor,
530
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
531
- attention_mask: Optional[torch.FloatTensor] = None,
532
- **cross_attention_kwargs,
533
- ) -> torch.Tensor:
534
- r"""
535
- The forward method of the `Attention` class.
536
-
537
- Args:
538
- hidden_states (`torch.Tensor`):
539
- The hidden states of the query.
540
- encoder_hidden_states (`torch.Tensor`, *optional*):
541
- The hidden states of the encoder.
542
- attention_mask (`torch.Tensor`, *optional*):
543
- The attention mask to use. If `None`, no mask is applied.
544
- **cross_attention_kwargs:
545
- Additional keyword arguments to pass along to the cross attention.
546
-
547
- Returns:
548
- `torch.Tensor`: The output of the attention layer.
549
- """
550
- # The `Attention` class can call different attention processors / attention functions
551
- # here we simply pass along all tensors to the selected processor class
552
- # For standard processors that are defined here, `**cross_attention_kwargs` is empty
553
- return self.processor(
554
- self,
555
- hidden_states,
556
- encoder_hidden_states=encoder_hidden_states,
557
- attention_mask=attention_mask,
558
- **cross_attention_kwargs,
559
- )
560
-
561
- def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
562
- r"""
563
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
564
- is the number of heads initialized while constructing the `Attention` class.
565
-
566
- Args:
567
- tensor (`torch.Tensor`): The tensor to reshape.
568
-
569
- Returns:
570
- `torch.Tensor`: The reshaped tensor.
571
- """
572
- head_size = self.heads
573
- batch_size, seq_len, dim = tensor.shape
574
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
575
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
576
- return tensor
577
-
578
- def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
579
- r"""
580
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
581
- the number of heads initialized while constructing the `Attention` class.
582
-
583
- Args:
584
- tensor (`torch.Tensor`): The tensor to reshape.
585
- out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
586
- reshaped to `[batch_size * heads, seq_len, dim // heads]`.
587
-
588
- Returns:
589
- `torch.Tensor`: The reshaped tensor.
590
- """
591
- head_size = self.heads
592
- batch_size, seq_len, dim = tensor.shape
593
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
594
- tensor = tensor.permute(0, 2, 1, 3)
595
-
596
- if out_dim == 3:
597
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
598
-
599
- return tensor
600
-
601
- def get_attention_scores(
602
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
603
- ) -> torch.Tensor:
604
- r"""
605
- Compute the attention scores.
606
-
607
- Args:
608
- query (`torch.Tensor`): The query tensor.
609
- key (`torch.Tensor`): The key tensor.
610
- attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
611
-
612
- Returns:
613
- `torch.Tensor`: The attention probabilities/scores.
614
- """
615
- dtype = query.dtype
616
- if self.upcast_attention:
617
- query = query.float()
618
- key = key.float()
619
-
620
- if attention_mask is None:
621
- baddbmm_input = torch.empty(
622
- query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
623
- )
624
- beta = 0
625
- else:
626
- baddbmm_input = attention_mask
627
- beta = 1
628
-
629
- attention_scores = torch.baddbmm(
630
- baddbmm_input,
631
- query,
632
- key.transpose(-1, -2),
633
- beta=beta,
634
- alpha=self.scale,
635
- )
636
- del baddbmm_input
637
-
638
- if self.upcast_softmax:
639
- attention_scores = attention_scores.float()
640
-
641
- attention_probs = attention_scores.softmax(dim=-1)
642
- del attention_scores
643
-
644
- attention_probs = attention_probs.to(dtype)
645
-
646
- return attention_probs
647
-
648
- def prepare_attention_mask(
649
- self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None,
650
- ) -> torch.Tensor:
651
- r"""
652
- Prepare the attention mask for the attention computation.
653
-
654
- Args:
655
- attention_mask (`torch.Tensor`):
656
- The attention mask to prepare.
657
- target_length (`int`):
658
- The target length of the attention mask. This is the length of the attention mask after padding.
659
- batch_size (`int`):
660
- The batch size, which is used to repeat the attention mask.
661
- out_dim (`int`, *optional*, defaults to `3`):
662
- The output dimension of the attention mask. Can be either `3` or `4`.
663
-
664
- Returns:
665
- `torch.Tensor`: The prepared attention mask.
666
- """
667
- head_size = head_size if head_size is not None else self.heads
668
- if attention_mask is None:
669
- return attention_mask
670
-
671
- current_length: int = attention_mask.shape[-1]
672
- if current_length != target_length:
673
- if attention_mask.device.type == "mps":
674
- # HACK: MPS: Does not support padding by greater than dimension of input tensor.
675
- # Instead, we can manually construct the padding tensor.
676
- padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
677
- padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
678
- attention_mask = torch.cat([attention_mask, padding], dim=2)
679
- else:
680
- # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
681
- # we want to instead pad by (0, remaining_length), where remaining_length is:
682
- # remaining_length: int = target_length - current_length
683
- # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
684
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
685
-
686
- if out_dim == 3:
687
- if attention_mask.shape[0] < batch_size * head_size:
688
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
689
- elif out_dim == 4:
690
- attention_mask = attention_mask.unsqueeze(1)
691
- attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
692
-
693
- return attention_mask
694
-
695
- def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
696
- r"""
697
- Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
698
- `Attention` class.
699
-
700
- Args:
701
- encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
702
-
703
- Returns:
704
- `torch.Tensor`: The normalized encoder hidden states.
705
- """
706
- assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
707
-
708
- if isinstance(self.norm_cross, nn.LayerNorm):
709
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
710
- elif isinstance(self.norm_cross, nn.GroupNorm):
711
- # Group norm norms along the channels dimension and expects
712
- # input to be in the shape of (N, C, *). In this case, we want
713
- # to norm along the hidden dimension, so we need to move
714
- # (batch_size, sequence_length, hidden_size) ->
715
- # (batch_size, hidden_size, sequence_length)
716
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
717
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
718
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
719
- else:
720
- assert False
721
-
722
- return encoder_hidden_states
723
-
724
- def _init_compress(self):
725
- self.sr.bias.data.zero_()
726
- self.norm = nn.LayerNorm(self.inner_dim)
727
-
728
-
729
- class AttnProcessor2_0(nn.Module):
730
- r"""
731
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
732
- """
733
-
734
- def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=None):
735
- super().__init__()
736
- self.attention_mode = attention_mode
737
- self.use_rope = use_rope
738
- self.interpolation_scale_thw = interpolation_scale_thw
739
-
740
- if self.use_rope:
741
- self._init_rope(interpolation_scale_thw)
742
-
743
- if not hasattr(F, "scaled_dot_product_attention"):
744
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
745
-
746
- def _init_rope(self, interpolation_scale_thw):
747
- self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
748
- self.position_getter = PositionGetter3D()
749
-
750
- def __call__(
751
- self,
752
- attn: Attention,
753
- hidden_states: torch.FloatTensor,
754
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
755
- attention_mask: Optional[torch.FloatTensor] = None,
756
- temb: Optional[torch.FloatTensor] = None,
757
- frame: int = 8,
758
- height: int = 16,
759
- width: int = 16,
760
- ) -> torch.FloatTensor:
761
-
762
- residual = hidden_states
763
-
764
- if attn.spatial_norm is not None:
765
- hidden_states = attn.spatial_norm(hidden_states, temb)
766
-
767
- input_ndim = hidden_states.ndim
768
-
769
- if input_ndim == 4:
770
- batch_size, channel, height, width = hidden_states.shape
771
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
772
-
773
-
774
- batch_size, sequence_length, _ = (
775
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
776
- )
777
-
778
- if attention_mask is not None and self.attention_mode == 'xformers':
779
- attention_heads = attn.heads
780
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, head_size=attention_heads)
781
- attention_mask = attention_mask.view(batch_size, attention_heads, -1, attention_mask.shape[-1])
782
- else:
783
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
784
- # scaled_dot_product_attention expects attention_mask shape to be
785
- # (batch, heads, source_length, target_length)
786
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
787
-
788
- if attn.group_norm is not None:
789
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
790
-
791
- query = attn.to_q(hidden_states)
792
-
793
- if encoder_hidden_states is None:
794
- encoder_hidden_states = hidden_states
795
- elif attn.norm_cross:
796
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
797
-
798
- key = attn.to_k(encoder_hidden_states)
799
- value = attn.to_v(encoder_hidden_states)
800
-
801
-
802
-
803
- attn_heads = attn.heads
804
-
805
- inner_dim = key.shape[-1]
806
- head_dim = inner_dim // attn_heads
807
-
808
- query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
809
- key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
810
- value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
811
-
812
-
813
- if self.use_rope:
814
- # require the shape of (batch_size x nheads x ntokens x dim)
815
- pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device)
816
- query = self.rope(query, pos_thw)
817
- key = self.rope(key, pos_thw)
818
-
819
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
820
- # TODO: add support for attn.scale when we move to Torch 2.1
821
- if self.attention_mode == 'flash':
822
- # assert attention_mask is None, 'flash-attn do not support attention_mask'
823
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
824
- hidden_states = F.scaled_dot_product_attention(
825
- query, key, value, dropout_p=0.0, is_causal=False
826
- )
827
- elif self.attention_mode == 'xformers':
828
- with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
829
- hidden_states = F.scaled_dot_product_attention(
830
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
831
- )
832
-
833
-
834
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
835
- hidden_states = hidden_states.to(query.dtype)
836
-
837
- # linear proj
838
- hidden_states = attn.to_out[0](hidden_states)
839
- # dropout
840
- hidden_states = attn.to_out[1](hidden_states)
841
-
842
- if input_ndim == 4:
843
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
844
-
845
- if attn.residual_connection:
846
- hidden_states = hidden_states + residual
847
-
848
- hidden_states = hidden_states / attn.rescale_output_factor
849
-
850
- return hidden_states
851
-
852
- class FeedForward(nn.Module):
853
- r"""
854
- A feed-forward layer.
855
-
856
- Parameters:
857
- dim (`int`): The number of channels in the input.
858
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
859
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
860
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
861
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
862
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
863
- """
864
-
865
- def __init__(
866
- self,
867
- dim: int,
868
- dim_out: Optional[int] = None,
869
- mult: int = 4,
870
- dropout: float = 0.0,
871
- activation_fn: str = "geglu",
872
- final_dropout: bool = False,
873
- ):
874
- super().__init__()
875
- inner_dim = int(dim * mult)
876
- dim_out = dim_out if dim_out is not None else dim
877
- linear_cls = nn.Linear
878
-
879
- if activation_fn == "gelu":
880
- act_fn = GELU(dim, inner_dim)
881
- if activation_fn == "gelu-approximate":
882
- act_fn = GELU(dim, inner_dim, approximate="tanh")
883
- elif activation_fn == "geglu":
884
- act_fn = GEGLU(dim, inner_dim)
885
- elif activation_fn == "geglu-approximate":
886
- act_fn = ApproximateGELU(dim, inner_dim)
887
-
888
- self.net = nn.ModuleList([])
889
- # project in
890
- self.net.append(act_fn)
891
- # project dropout
892
- self.net.append(nn.Dropout(dropout))
893
- # project out
894
- self.net.append(linear_cls(inner_dim, dim_out))
895
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
896
- if final_dropout:
897
- self.net.append(nn.Dropout(dropout))
898
-
899
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
900
- for module in self.net:
901
- hidden_states = module(hidden_states)
902
- return hidden_states
903
-
904
-
905
- @maybe_allow_in_graph
906
- class BasicTransformerBlock(nn.Module):
907
- r"""
908
- A basic Transformer block.
909
-
910
- Parameters:
911
- dim (`int`): The number of channels in the input and output.
912
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
913
- attention_head_dim (`int`): The number of channels in each head.
914
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
915
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
916
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
917
- num_embeds_ada_norm (:
918
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
919
- attention_bias (:
920
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
921
- only_cross_attention (`bool`, *optional*):
922
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
923
- double_self_attention (`bool`, *optional*):
924
- Whether to use two self-attention layers. In this case no cross attention layers are used.
925
- upcast_attention (`bool`, *optional*):
926
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
927
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
928
- Whether to use learnable elementwise affine parameters for normalization.
929
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
930
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
931
- final_dropout (`bool` *optional*, defaults to False):
932
- Whether to apply a final dropout after the last feed-forward layer.
933
- positional_embeddings (`str`, *optional*, defaults to `None`):
934
- The type of positional embeddings to apply to.
935
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
936
- The maximum number of positional embeddings to apply.
937
- """
938
-
939
- def __init__(
940
- self,
941
- dim: int,
942
- num_attention_heads: int,
943
- attention_head_dim: int,
944
- dropout=0.0,
945
- cross_attention_dim: Optional[int] = None,
946
- activation_fn: str = "geglu",
947
- num_embeds_ada_norm: Optional[int] = None,
948
- attention_bias: bool = False,
949
- only_cross_attention: bool = False,
950
- double_self_attention: bool = False,
951
- upcast_attention: bool = False,
952
- norm_elementwise_affine: bool = True,
953
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
954
- norm_eps: float = 1e-5,
955
- final_dropout: bool = False,
956
- positional_embeddings: Optional[str] = None,
957
- num_positional_embeddings: Optional[int] = None,
958
- sa_attention_mode: str = "flash",
959
- ca_attention_mode: str = "xformers",
960
- use_rope: bool = False,
961
- interpolation_scale_thw: Tuple[int] = (1, 1, 1),
962
- block_idx: Optional[int] = None,
963
- ):
964
- super().__init__()
965
- self.only_cross_attention = only_cross_attention
966
-
967
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
968
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
969
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
970
- self.use_layer_norm = norm_type == "layer_norm"
971
-
972
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
973
- raise ValueError(
974
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
975
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
976
- )
977
-
978
- if positional_embeddings and (num_positional_embeddings is None):
979
- raise ValueError(
980
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
981
- )
982
-
983
- if positional_embeddings == "sinusoidal":
984
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
985
- else:
986
- self.pos_embed = None
987
-
988
- # Define 3 blocks. Each block has its own normalization layer.
989
- # 1. Self-Attn
990
- if self.use_ada_layer_norm:
991
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
992
- elif self.use_ada_layer_norm_zero:
993
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
994
- else:
995
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
996
-
997
- self.attn1 = Attention(
998
- query_dim=dim,
999
- heads=num_attention_heads,
1000
- dim_head=attention_head_dim,
1001
- dropout=dropout,
1002
- bias=attention_bias,
1003
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1004
- upcast_attention=upcast_attention,
1005
- attention_mode=sa_attention_mode,
1006
- use_rope=use_rope,
1007
- interpolation_scale_thw=interpolation_scale_thw,
1008
- )
1009
-
1010
- # 2. Cross-Attn
1011
- if cross_attention_dim is not None or double_self_attention:
1012
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
1013
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
1014
- # the second cross attention block.
1015
- self.norm2 = (
1016
- AdaLayerNorm(dim, num_embeds_ada_norm)
1017
- if self.use_ada_layer_norm
1018
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1019
- )
1020
- self.attn2 = Attention(
1021
- query_dim=dim,
1022
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1023
- heads=num_attention_heads,
1024
- dim_head=attention_head_dim,
1025
- dropout=dropout,
1026
- bias=attention_bias,
1027
- upcast_attention=upcast_attention,
1028
- attention_mode=ca_attention_mode, # only xformers support attention_mask
1029
- use_rope=False, # do not position in cross attention
1030
- interpolation_scale_thw=interpolation_scale_thw,
1031
- ) # is self-attn if encoder_hidden_states is none
1032
- else:
1033
- self.norm2 = None
1034
- self.attn2 = None
1035
-
1036
- # 3. Feed-forward
1037
-
1038
- if not self.use_ada_layer_norm_single:
1039
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1040
-
1041
- self.ff = FeedForward(
1042
- dim,
1043
- dropout=dropout,
1044
- activation_fn=activation_fn,
1045
- final_dropout=final_dropout,
1046
- )
1047
-
1048
- # 5. Scale-shift for PixArt-Alpha.
1049
- if self.use_ada_layer_norm_single:
1050
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
1051
-
1052
-
1053
- def forward(
1054
- self,
1055
- hidden_states: torch.FloatTensor,
1056
- attention_mask: Optional[torch.FloatTensor] = None,
1057
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1058
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1059
- timestep: Optional[torch.LongTensor] = None,
1060
- cross_attention_kwargs: Dict[str, Any] = None,
1061
- class_labels: Optional[torch.LongTensor] = None,
1062
- frame: int = None,
1063
- height: int = None,
1064
- width: int = None,
1065
- ) -> torch.FloatTensor:
1066
- # Notice that normalization is always applied before the real computation in the following blocks.
1067
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1068
-
1069
- # 0. Self-Attention
1070
- batch_size = hidden_states.shape[0]
1071
-
1072
- if self.use_ada_layer_norm:
1073
- norm_hidden_states = self.norm1(hidden_states, timestep)
1074
- elif self.use_ada_layer_norm_zero:
1075
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
1076
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
1077
- )
1078
- elif self.use_layer_norm:
1079
- norm_hidden_states = self.norm1(hidden_states)
1080
- elif self.use_ada_layer_norm_single:
1081
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
1082
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
1083
- ).chunk(6, dim=1)
1084
- norm_hidden_states = self.norm1(hidden_states)
1085
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
1086
- norm_hidden_states = norm_hidden_states.squeeze(1)
1087
- else:
1088
- raise ValueError("Incorrect norm used")
1089
-
1090
- if self.pos_embed is not None:
1091
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1092
-
1093
- attn_output = self.attn1(
1094
- norm_hidden_states,
1095
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1096
- attention_mask=attention_mask,
1097
- frame=frame,
1098
- height=height,
1099
- width=width,
1100
- **cross_attention_kwargs,
1101
- )
1102
- if self.use_ada_layer_norm_zero:
1103
- attn_output = gate_msa.unsqueeze(1) * attn_output
1104
- elif self.use_ada_layer_norm_single:
1105
- attn_output = gate_msa * attn_output
1106
-
1107
- hidden_states = attn_output + hidden_states
1108
- if hidden_states.ndim == 4:
1109
- hidden_states = hidden_states.squeeze(1)
1110
-
1111
- # 1. Cross-Attention
1112
- if self.attn2 is not None:
1113
-
1114
- if self.use_ada_layer_norm:
1115
- norm_hidden_states = self.norm2(hidden_states, timestep)
1116
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
1117
- norm_hidden_states = self.norm2(hidden_states)
1118
- elif self.use_ada_layer_norm_single:
1119
- # For PixArt norm2 isn't applied here:
1120
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
1121
- norm_hidden_states = hidden_states
1122
- else:
1123
- raise ValueError("Incorrect norm")
1124
-
1125
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
1126
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1127
-
1128
- attn_output = self.attn2(
1129
- norm_hidden_states,
1130
- encoder_hidden_states=encoder_hidden_states,
1131
- attention_mask=encoder_attention_mask,
1132
- **cross_attention_kwargs,
1133
- )
1134
- hidden_states = attn_output + hidden_states
1135
-
1136
-
1137
- # 2. Feed-forward
1138
- if not self.use_ada_layer_norm_single:
1139
- norm_hidden_states = self.norm3(hidden_states)
1140
-
1141
- if self.use_ada_layer_norm_zero:
1142
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1143
-
1144
- if self.use_ada_layer_norm_single:
1145
- norm_hidden_states = self.norm2(hidden_states)
1146
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1147
-
1148
- ff_output = self.ff(norm_hidden_states)
1149
-
1150
- if self.use_ada_layer_norm_zero:
1151
- ff_output = gate_mlp.unsqueeze(1) * ff_output
1152
- elif self.use_ada_layer_norm_single:
1153
- ff_output = gate_mlp * ff_output
1154
-
1155
-
1156
- hidden_states = ff_output + hidden_states
1157
- if hidden_states.ndim == 4:
1158
- hidden_states = hidden_states.squeeze(1)
1159
-
1160
- return hidden_states
1161
-
1162
-
1163
- class AdaLayerNormSingle(nn.Module):
1164
- r"""
1165
- Norm layer adaptive layer norm single (adaLN-single).
1166
-
1167
- As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
1168
-
1169
- Parameters:
1170
- embedding_dim (`int`): The size of each embedding vector.
1171
- use_additional_conditions (`bool`): To use additional conditions for normalization or not.
1172
- """
1173
-
1174
- def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
1175
- super().__init__()
1176
-
1177
- self.emb = CombinedTimestepSizeEmbeddings(
1178
- embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
1179
- )
1180
-
1181
- self.silu = nn.SiLU()
1182
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
1183
-
1184
- def forward(
1185
- self,
1186
- timestep: torch.Tensor,
1187
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
1188
- batch_size: int = None,
1189
- hidden_dtype: Optional[torch.dtype] = None,
1190
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1191
- # No modulation happening here.
1192
- embedded_timestep = self.emb(
1193
- timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
1194
- )
1195
- return self.linear(self.silu(embedded_timestep)), embedded_timestep