AI-Anchorite commited on
Commit
3ed9da5
1 Parent(s): 4e18f8a

Upload block.py

Browse files
Files changed (1) hide show
  1. allegro/models/transformers/block.py +1200 -0
allegro/models/transformers/block.py ADDED
@@ -0,0 +1,1200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from Open-Sora-Plan
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
+ # --------------------------------------------------------
9
+
10
+
11
+ from importlib import import_module
12
+ from typing import Any, Callable, Dict, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ import collections
17
+ import torch.nn.functional as F
18
+ from torch.nn.attention import SDPBackend, sdpa_kernel
19
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
20
+ from diffusers.models.attention_processor import (
21
+ AttnAddedKVProcessor,
22
+ AttnAddedKVProcessor2_0,
23
+ AttnProcessor,
24
+ CustomDiffusionAttnProcessor,
25
+ CustomDiffusionAttnProcessor2_0,
26
+ CustomDiffusionXFormersAttnProcessor,
27
+ LoRAAttnAddedKVProcessor,
28
+ LoRAAttnProcessor,
29
+ LoRAAttnProcessor2_0,
30
+ LoRAXFormersAttnProcessor,
31
+ SlicedAttnAddedKVProcessor,
32
+ SlicedAttnProcessor,
33
+ SpatialNorm,
34
+ XFormersAttnAddedKVProcessor,
35
+ XFormersAttnProcessor,
36
+ )
37
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
38
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
39
+ from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_xformers_available
40
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
41
+ from torch import nn
42
+
43
+ from allegro.models.transformers.rope import RoPE3D, PositionGetter3D
44
+ from allegro.models.transformers.embedding import CombinedTimestepSizeEmbeddings
45
+
46
+ if is_xformers_available():
47
+ import xformers
48
+ import xformers.ops
49
+ else:
50
+ xformers = None
51
+
52
+ from diffusers.utils import logging
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ def to_2tuple(x):
58
+ if isinstance(x, collections.abc.Iterable):
59
+ return x
60
+ return (x, x)
61
+
62
+
63
+ @maybe_allow_in_graph
64
+ class Attention(nn.Module):
65
+ r"""
66
+ A cross attention layer.
67
+
68
+ Parameters:
69
+ query_dim (`int`):
70
+ The number of channels in the query.
71
+ cross_attention_dim (`int`, *optional*):
72
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
73
+ heads (`int`, *optional*, defaults to 8):
74
+ The number of heads to use for multi-head attention.
75
+ dim_head (`int`, *optional*, defaults to 64):
76
+ The number of channels in each head.
77
+ dropout (`float`, *optional*, defaults to 0.0):
78
+ The dropout probability to use.
79
+ bias (`bool`, *optional*, defaults to False):
80
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
81
+ upcast_attention (`bool`, *optional*, defaults to False):
82
+ Set to `True` to upcast the attention computation to `float32`.
83
+ upcast_softmax (`bool`, *optional*, defaults to False):
84
+ Set to `True` to upcast the softmax computation to `float32`.
85
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
86
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
87
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
88
+ The number of groups to use for the group norm in the cross attention.
89
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
90
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
91
+ norm_num_groups (`int`, *optional*, defaults to `None`):
92
+ The number of groups to use for the group norm in the attention.
93
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
94
+ The number of channels to use for the spatial normalization.
95
+ out_bias (`bool`, *optional*, defaults to `True`):
96
+ Set to `True` to use a bias in the output linear layer.
97
+ scale_qk (`bool`, *optional*, defaults to `True`):
98
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
99
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
100
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
101
+ `added_kv_proj_dim` is not `None`.
102
+ eps (`float`, *optional*, defaults to 1e-5):
103
+ An additional value added to the denominator in group normalization that is used for numerical stability.
104
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
105
+ A factor to rescale the output by dividing it with this value.
106
+ residual_connection (`bool`, *optional*, defaults to `False`):
107
+ Set to `True` to add the residual connection to the output.
108
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
109
+ Set to `True` if the attention block is loaded from a deprecated state dict.
110
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
111
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
112
+ `AttnProcessor` otherwise.
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ query_dim: int,
118
+ cross_attention_dim: Optional[int] = None,
119
+ heads: int = 8,
120
+ dim_head: int = 64,
121
+ dropout: float = 0.0,
122
+ bias: bool = False,
123
+ upcast_attention: bool = False,
124
+ upcast_softmax: bool = False,
125
+ cross_attention_norm: Optional[str] = None,
126
+ cross_attention_norm_num_groups: int = 32,
127
+ added_kv_proj_dim: Optional[int] = None,
128
+ norm_num_groups: Optional[int] = None,
129
+ spatial_norm_dim: Optional[int] = None,
130
+ out_bias: bool = True,
131
+ scale_qk: bool = True,
132
+ only_cross_attention: bool = False,
133
+ eps: float = 1e-5,
134
+ rescale_output_factor: float = 1.0,
135
+ residual_connection: bool = False,
136
+ _from_deprecated_attn_block: bool = False,
137
+ processor: Optional["AttnProcessor"] = None,
138
+ attention_mode: str = "xformers",
139
+ use_rope: bool = False,
140
+ interpolation_scale_thw=None,
141
+ ):
142
+ super().__init__()
143
+ self.inner_dim = dim_head * heads
144
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
145
+ self.upcast_attention = upcast_attention
146
+ self.upcast_softmax = upcast_softmax
147
+ self.rescale_output_factor = rescale_output_factor
148
+ self.residual_connection = residual_connection
149
+ self.dropout = dropout
150
+ self.use_rope = use_rope
151
+
152
+ # we make use of this private variable to know whether this class is loaded
153
+ # with an deprecated state dict so that we can convert it on the fly
154
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
155
+
156
+ self.scale_qk = scale_qk
157
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
158
+
159
+ self.heads = heads
160
+ # for slice_size > 0 the attention score computation
161
+ # is split across the batch axis to save memory
162
+ # You can set slice_size with `set_attention_slice`
163
+ self.sliceable_head_dim = heads
164
+
165
+ self.added_kv_proj_dim = added_kv_proj_dim
166
+ self.only_cross_attention = only_cross_attention
167
+
168
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
169
+ raise ValueError(
170
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
171
+ )
172
+
173
+ if norm_num_groups is not None:
174
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
175
+ else:
176
+ self.group_norm = None
177
+
178
+ if spatial_norm_dim is not None:
179
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
180
+ else:
181
+ self.spatial_norm = None
182
+
183
+ if cross_attention_norm is None:
184
+ self.norm_cross = None
185
+ elif cross_attention_norm == "layer_norm":
186
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
187
+ elif cross_attention_norm == "group_norm":
188
+ if self.added_kv_proj_dim is not None:
189
+ # The given `encoder_hidden_states` are initially of shape
190
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
191
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
192
+ # before the projection, so we need to use `added_kv_proj_dim` as
193
+ # the number of channels for the group norm.
194
+ norm_cross_num_channels = added_kv_proj_dim
195
+ else:
196
+ norm_cross_num_channels = self.cross_attention_dim
197
+
198
+ self.norm_cross = nn.GroupNorm(
199
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
200
+ )
201
+ else:
202
+ raise ValueError(
203
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
204
+ )
205
+
206
+ linear_cls = nn.Linear
207
+
208
+
209
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
210
+
211
+ if not self.only_cross_attention:
212
+ # only relevant for the `AddedKVProcessor` classes
213
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
214
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
215
+ else:
216
+ self.to_k = None
217
+ self.to_v = None
218
+
219
+ if self.added_kv_proj_dim is not None:
220
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
221
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
222
+
223
+ self.to_out = nn.ModuleList([])
224
+ self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
225
+ self.to_out.append(nn.Dropout(dropout))
226
+
227
+ # set attention processor
228
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
229
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
230
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
231
+ if processor is None:
232
+ processor = (
233
+ AttnProcessor2_0(
234
+ attention_mode,
235
+ use_rope,
236
+ interpolation_scale_thw=interpolation_scale_thw,
237
+ )
238
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
239
+ else AttnProcessor()
240
+ )
241
+ self.set_processor(processor)
242
+
243
+ def set_use_memory_efficient_attention_xformers(
244
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
245
+ ) -> None:
246
+ r"""
247
+ Set whether to use memory efficient attention from `xformers` or not.
248
+
249
+ Args:
250
+ use_memory_efficient_attention_xformers (`bool`):
251
+ Whether to use memory efficient attention from `xformers` or not.
252
+ attention_op (`Callable`, *optional*):
253
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
254
+ `xformers`.
255
+ """
256
+ is_lora = hasattr(self, "processor")
257
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
258
+ self.processor,
259
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
260
+ )
261
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
262
+ self.processor,
263
+ (
264
+ AttnAddedKVProcessor,
265
+ AttnAddedKVProcessor2_0,
266
+ SlicedAttnAddedKVProcessor,
267
+ XFormersAttnAddedKVProcessor,
268
+ LoRAAttnAddedKVProcessor,
269
+ ),
270
+ )
271
+
272
+ if use_memory_efficient_attention_xformers:
273
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
274
+ raise NotImplementedError(
275
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
276
+ )
277
+ if not is_xformers_available():
278
+ raise ModuleNotFoundError(
279
+ (
280
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
281
+ " xformers"
282
+ ),
283
+ name="xformers",
284
+ )
285
+ elif not torch.cuda.is_available():
286
+ raise ValueError(
287
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
288
+ " only available for GPU "
289
+ )
290
+ else:
291
+ try:
292
+ # Make sure we can run the memory efficient attention
293
+ _ = xformers.ops.memory_efficient_attention(
294
+ torch.randn((1, 2, 40), device="cuda"),
295
+ torch.randn((1, 2, 40), device="cuda"),
296
+ torch.randn((1, 2, 40), device="cuda"),
297
+ )
298
+ except Exception as e:
299
+ raise e
300
+
301
+ if is_lora:
302
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
303
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
304
+ processor = LoRAXFormersAttnProcessor(
305
+ hidden_size=self.processor.hidden_size,
306
+ cross_attention_dim=self.processor.cross_attention_dim,
307
+ rank=self.processor.rank,
308
+ attention_op=attention_op,
309
+ )
310
+ processor.load_state_dict(self.processor.state_dict())
311
+ processor.to(self.processor.to_q_lora.up.weight.device)
312
+ elif is_custom_diffusion:
313
+ processor = CustomDiffusionXFormersAttnProcessor(
314
+ train_kv=self.processor.train_kv,
315
+ train_q_out=self.processor.train_q_out,
316
+ hidden_size=self.processor.hidden_size,
317
+ cross_attention_dim=self.processor.cross_attention_dim,
318
+ attention_op=attention_op,
319
+ )
320
+ processor.load_state_dict(self.processor.state_dict())
321
+ if hasattr(self.processor, "to_k_custom_diffusion"):
322
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
323
+ elif is_added_kv_processor:
324
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
325
+ # which uses this type of cross attention ONLY because the attention mask of format
326
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
327
+ # throw warning
328
+ logger.info(
329
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
330
+ )
331
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
332
+ else:
333
+ processor = XFormersAttnProcessor(attention_op=attention_op)
334
+ else:
335
+ if is_lora:
336
+ attn_processor_class = (
337
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
338
+ )
339
+ processor = attn_processor_class(
340
+ hidden_size=self.processor.hidden_size,
341
+ cross_attention_dim=self.processor.cross_attention_dim,
342
+ rank=self.processor.rank,
343
+ )
344
+ processor.load_state_dict(self.processor.state_dict())
345
+ processor.to(self.processor.to_q_lora.up.weight.device)
346
+ elif is_custom_diffusion:
347
+ attn_processor_class = (
348
+ CustomDiffusionAttnProcessor2_0
349
+ if hasattr(F, "scaled_dot_product_attention")
350
+ else CustomDiffusionAttnProcessor
351
+ )
352
+ processor = attn_processor_class(
353
+ train_kv=self.processor.train_kv,
354
+ train_q_out=self.processor.train_q_out,
355
+ hidden_size=self.processor.hidden_size,
356
+ cross_attention_dim=self.processor.cross_attention_dim,
357
+ )
358
+ processor.load_state_dict(self.processor.state_dict())
359
+ if hasattr(self.processor, "to_k_custom_diffusion"):
360
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
361
+ else:
362
+ # set attention processor
363
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
364
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
365
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
366
+ processor = (
367
+ AttnProcessor2_0()
368
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
369
+ else AttnProcessor()
370
+ )
371
+
372
+ self.set_processor(processor)
373
+
374
+ def set_attention_slice(self, slice_size: int) -> None:
375
+ r"""
376
+ Set the slice size for attention computation.
377
+
378
+ Args:
379
+ slice_size (`int`):
380
+ The slice size for attention computation.
381
+ """
382
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
383
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
384
+
385
+ if slice_size is not None and self.added_kv_proj_dim is not None:
386
+ processor = SlicedAttnAddedKVProcessor(slice_size)
387
+ elif slice_size is not None:
388
+ processor = SlicedAttnProcessor(slice_size)
389
+ elif self.added_kv_proj_dim is not None:
390
+ processor = AttnAddedKVProcessor()
391
+ else:
392
+ # set attention processor
393
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
394
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
395
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
396
+ processor = (
397
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
398
+ )
399
+
400
+ self.set_processor(processor)
401
+
402
+ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
403
+ r"""
404
+ Set the attention processor to use.
405
+
406
+ Args:
407
+ processor (`AttnProcessor`):
408
+ The attention processor to use.
409
+ _remove_lora (`bool`, *optional*, defaults to `False`):
410
+ Set to `True` to remove LoRA layers from the model.
411
+ """
412
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
413
+ deprecate(
414
+ "set_processor to offload LoRA",
415
+ "0.26.0",
416
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
417
+ )
418
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
419
+ # We need to remove all LoRA layers
420
+ # Don't forget to remove ALL `_remove_lora` from the codebase
421
+ for module in self.modules():
422
+ if hasattr(module, "set_lora_layer"):
423
+ module.set_lora_layer(None)
424
+
425
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
426
+ # pop `processor` from `self._modules`
427
+ if (
428
+ hasattr(self, "processor")
429
+ and isinstance(self.processor, torch.nn.Module)
430
+ and not isinstance(processor, torch.nn.Module)
431
+ ):
432
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
433
+ self._modules.pop("processor")
434
+
435
+ self.processor = processor
436
+
437
+ def get_processor(self, return_deprecated_lora: bool = False):
438
+ r"""
439
+ Get the attention processor in use.
440
+
441
+ Args:
442
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
443
+ Set to `True` to return the deprecated LoRA attention processor.
444
+
445
+ Returns:
446
+ "AttentionProcessor": The attention processor in use.
447
+ """
448
+ if not return_deprecated_lora:
449
+ return self.processor
450
+
451
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
452
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
453
+ # with PEFT is completed.
454
+ is_lora_activated = {
455
+ name: module.lora_layer is not None
456
+ for name, module in self.named_modules()
457
+ if hasattr(module, "lora_layer")
458
+ }
459
+
460
+ # 1. if no layer has a LoRA activated we can return the processor as usual
461
+ if not any(is_lora_activated.values()):
462
+ return self.processor
463
+
464
+ # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
465
+ is_lora_activated.pop("add_k_proj", None)
466
+ is_lora_activated.pop("add_v_proj", None)
467
+ # 2. else it is not posssible that only some layers have LoRA activated
468
+ if not all(is_lora_activated.values()):
469
+ raise ValueError(
470
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
471
+ )
472
+
473
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
474
+ non_lora_processor_cls_name = self.processor.__class__.__name__
475
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
476
+
477
+ hidden_size = self.inner_dim
478
+
479
+ # now create a LoRA attention processor from the LoRA layers
480
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
481
+ kwargs = {
482
+ "cross_attention_dim": self.cross_attention_dim,
483
+ "rank": self.to_q.lora_layer.rank,
484
+ "network_alpha": self.to_q.lora_layer.network_alpha,
485
+ "q_rank": self.to_q.lora_layer.rank,
486
+ "q_hidden_size": self.to_q.lora_layer.out_features,
487
+ "k_rank": self.to_k.lora_layer.rank,
488
+ "k_hidden_size": self.to_k.lora_layer.out_features,
489
+ "v_rank": self.to_v.lora_layer.rank,
490
+ "v_hidden_size": self.to_v.lora_layer.out_features,
491
+ "out_rank": self.to_out[0].lora_layer.rank,
492
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
493
+ }
494
+
495
+ if hasattr(self.processor, "attention_op"):
496
+ kwargs["attention_op"] = self.processor.attention_op
497
+
498
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
499
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
500
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
501
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
502
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
503
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
504
+ lora_processor = lora_processor_cls(
505
+ hidden_size,
506
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
507
+ rank=self.to_q.lora_layer.rank,
508
+ network_alpha=self.to_q.lora_layer.network_alpha,
509
+ )
510
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
511
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
512
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
513
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
514
+
515
+ # only save if used
516
+ if self.add_k_proj.lora_layer is not None:
517
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
518
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
519
+ else:
520
+ lora_processor.add_k_proj_lora = None
521
+ lora_processor.add_v_proj_lora = None
522
+ else:
523
+ raise ValueError(f"{lora_processor_cls} does not exist.")
524
+
525
+ return lora_processor
526
+
527
+ def forward(
528
+ self,
529
+ hidden_states: torch.FloatTensor,
530
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
531
+ attention_mask: Optional[torch.FloatTensor] = None,
532
+ **cross_attention_kwargs,
533
+ ) -> torch.Tensor:
534
+ r"""
535
+ The forward method of the `Attention` class.
536
+
537
+ Args:
538
+ hidden_states (`torch.Tensor`):
539
+ The hidden states of the query.
540
+ encoder_hidden_states (`torch.Tensor`, *optional*):
541
+ The hidden states of the encoder.
542
+ attention_mask (`torch.Tensor`, *optional*):
543
+ The attention mask to use. If `None`, no mask is applied.
544
+ **cross_attention_kwargs:
545
+ Additional keyword arguments to pass along to the cross attention.
546
+
547
+ Returns:
548
+ `torch.Tensor`: The output of the attention layer.
549
+ """
550
+ # The `Attention` class can call different attention processors / attention functions
551
+ # here we simply pass along all tensors to the selected processor class
552
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
553
+ return self.processor(
554
+ self,
555
+ hidden_states,
556
+ encoder_hidden_states=encoder_hidden_states,
557
+ attention_mask=attention_mask,
558
+ **cross_attention_kwargs,
559
+ )
560
+
561
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
562
+ r"""
563
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
564
+ is the number of heads initialized while constructing the `Attention` class.
565
+
566
+ Args:
567
+ tensor (`torch.Tensor`): The tensor to reshape.
568
+
569
+ Returns:
570
+ `torch.Tensor`: The reshaped tensor.
571
+ """
572
+ head_size = self.heads
573
+ batch_size, seq_len, dim = tensor.shape
574
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
575
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
576
+ return tensor
577
+
578
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
579
+ r"""
580
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
581
+ the number of heads initialized while constructing the `Attention` class.
582
+
583
+ Args:
584
+ tensor (`torch.Tensor`): The tensor to reshape.
585
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
586
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
587
+
588
+ Returns:
589
+ `torch.Tensor`: The reshaped tensor.
590
+ """
591
+ head_size = self.heads
592
+ batch_size, seq_len, dim = tensor.shape
593
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
594
+ tensor = tensor.permute(0, 2, 1, 3)
595
+
596
+ if out_dim == 3:
597
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
598
+
599
+ return tensor
600
+
601
+ def get_attention_scores(
602
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
603
+ ) -> torch.Tensor:
604
+ r"""
605
+ Compute the attention scores.
606
+
607
+ Args:
608
+ query (`torch.Tensor`): The query tensor.
609
+ key (`torch.Tensor`): The key tensor.
610
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
611
+
612
+ Returns:
613
+ `torch.Tensor`: The attention probabilities/scores.
614
+ """
615
+ dtype = query.dtype
616
+ if self.upcast_attention:
617
+ query = query.float()
618
+ key = key.float()
619
+
620
+ if attention_mask is None:
621
+ baddbmm_input = torch.empty(
622
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
623
+ )
624
+ beta = 0
625
+ else:
626
+ baddbmm_input = attention_mask
627
+ beta = 1
628
+
629
+ attention_scores = torch.baddbmm(
630
+ baddbmm_input,
631
+ query,
632
+ key.transpose(-1, -2),
633
+ beta=beta,
634
+ alpha=self.scale,
635
+ )
636
+ del baddbmm_input
637
+
638
+ if self.upcast_softmax:
639
+ attention_scores = attention_scores.float()
640
+
641
+ attention_probs = attention_scores.softmax(dim=-1)
642
+ del attention_scores
643
+
644
+ attention_probs = attention_probs.to(dtype)
645
+
646
+ return attention_probs
647
+
648
+ def prepare_attention_mask(
649
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3, head_size = None,
650
+ ) -> torch.Tensor:
651
+ r"""
652
+ Prepare the attention mask for the attention computation.
653
+
654
+ Args:
655
+ attention_mask (`torch.Tensor`):
656
+ The attention mask to prepare.
657
+ target_length (`int`):
658
+ The target length of the attention mask. This is the length of the attention mask after padding.
659
+ batch_size (`int`):
660
+ The batch size, which is used to repeat the attention mask.
661
+ out_dim (`int`, *optional*, defaults to `3`):
662
+ The output dimension of the attention mask. Can be either `3` or `4`.
663
+
664
+ Returns:
665
+ `torch.Tensor`: The prepared attention mask.
666
+ """
667
+ head_size = head_size if head_size is not None else self.heads
668
+ if attention_mask is None:
669
+ return attention_mask
670
+
671
+ current_length: int = attention_mask.shape[-1]
672
+ if current_length != target_length:
673
+ if attention_mask.device.type == "mps":
674
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
675
+ # Instead, we can manually construct the padding tensor.
676
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
677
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
678
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
679
+ else:
680
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
681
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
682
+ # remaining_length: int = target_length - current_length
683
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
684
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
685
+
686
+ if out_dim == 3:
687
+ if attention_mask.shape[0] < batch_size * head_size:
688
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
689
+ elif out_dim == 4:
690
+ attention_mask = attention_mask.unsqueeze(1)
691
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
692
+
693
+ return attention_mask
694
+
695
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
696
+ r"""
697
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
698
+ `Attention` class.
699
+
700
+ Args:
701
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
702
+
703
+ Returns:
704
+ `torch.Tensor`: The normalized encoder hidden states.
705
+ """
706
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
707
+
708
+ if isinstance(self.norm_cross, nn.LayerNorm):
709
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
710
+ elif isinstance(self.norm_cross, nn.GroupNorm):
711
+ # Group norm norms along the channels dimension and expects
712
+ # input to be in the shape of (N, C, *). In this case, we want
713
+ # to norm along the hidden dimension, so we need to move
714
+ # (batch_size, sequence_length, hidden_size) ->
715
+ # (batch_size, hidden_size, sequence_length)
716
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
717
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
718
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
719
+ else:
720
+ assert False
721
+
722
+ return encoder_hidden_states
723
+
724
+ def _init_compress(self):
725
+ self.sr.bias.data.zero_()
726
+ self.norm = nn.LayerNorm(self.inner_dim)
727
+
728
+
729
+ class AttnProcessor2_0(nn.Module):
730
+ r"""
731
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
732
+ """
733
+
734
+ def __init__(self, attention_mode="xformers", use_rope=False, interpolation_scale_thw=None):
735
+ super().__init__()
736
+ self.attention_mode = attention_mode
737
+ self.use_rope = use_rope
738
+ self.interpolation_scale_thw = interpolation_scale_thw
739
+
740
+ if self.use_rope:
741
+ self._init_rope(interpolation_scale_thw)
742
+
743
+ if not hasattr(F, "scaled_dot_product_attention"):
744
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
745
+
746
+ def _init_rope(self, interpolation_scale_thw):
747
+ self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw)
748
+ self.position_getter = PositionGetter3D()
749
+
750
+ def __call__(
751
+ self,
752
+ attn: Attention,
753
+ hidden_states: torch.FloatTensor,
754
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
755
+ attention_mask: Optional[torch.FloatTensor] = None,
756
+ temb: Optional[torch.FloatTensor] = None,
757
+ frame: int = 8,
758
+ height: int = 16,
759
+ width: int = 16,
760
+ ) -> torch.FloatTensor:
761
+
762
+ residual = hidden_states
763
+
764
+ if attn.spatial_norm is not None:
765
+ hidden_states = attn.spatial_norm(hidden_states, temb)
766
+
767
+ input_ndim = hidden_states.ndim
768
+
769
+ if input_ndim == 4:
770
+ batch_size, channel, height, width = hidden_states.shape
771
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
772
+
773
+
774
+ batch_size, sequence_length, _ = (
775
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
776
+ )
777
+
778
+ if attention_mask is not None and self.attention_mode == 'xformers':
779
+ attention_heads = attn.heads
780
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, head_size=attention_heads)
781
+ attention_mask = attention_mask.view(batch_size, attention_heads, -1, attention_mask.shape[-1])
782
+ else:
783
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
784
+ # scaled_dot_product_attention expects attention_mask shape to be
785
+ # (batch, heads, source_length, target_length)
786
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
787
+
788
+ if attn.group_norm is not None:
789
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
790
+
791
+ query = attn.to_q(hidden_states)
792
+
793
+ if encoder_hidden_states is None:
794
+ encoder_hidden_states = hidden_states
795
+ elif attn.norm_cross:
796
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
797
+
798
+ key = attn.to_k(encoder_hidden_states)
799
+ value = attn.to_v(encoder_hidden_states)
800
+
801
+
802
+
803
+ attn_heads = attn.heads
804
+
805
+ inner_dim = key.shape[-1]
806
+ head_dim = inner_dim // attn_heads
807
+
808
+ query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
809
+ key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
810
+ value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
811
+
812
+
813
+ if self.use_rope:
814
+ # require the shape of (batch_size x nheads x ntokens x dim)
815
+ pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device)
816
+ query = self.rope(query, pos_thw)
817
+ key = self.rope(key, pos_thw)
818
+
819
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
820
+ # TODO: add support for attn.scale when we move to Torch 2.1
821
+ # if self.attention_mode == 'flash':
822
+ # # assert attention_mask is None, 'flash-attn do not support attention_mask'
823
+ # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
824
+ # hidden_states = F.scaled_dot_product_attention(
825
+ # query, key, value, dropout_p=0.0, is_causal=False
826
+ # )
827
+ # elif self.attention_mode == 'xformers':
828
+ # with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
829
+ # hidden_states = F.scaled_dot_product_attention(
830
+ # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
831
+ # )
832
+
833
+ # Use basic attention implementation
834
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
835
+ hidden_states = F.scaled_dot_product_attention(
836
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
837
+ )
838
+
839
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
840
+ hidden_states = hidden_states.to(query.dtype)
841
+
842
+ # linear proj
843
+ hidden_states = attn.to_out[0](hidden_states)
844
+ # dropout
845
+ hidden_states = attn.to_out[1](hidden_states)
846
+
847
+ if input_ndim == 4:
848
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
849
+
850
+ if attn.residual_connection:
851
+ hidden_states = hidden_states + residual
852
+
853
+ hidden_states = hidden_states / attn.rescale_output_factor
854
+
855
+ return hidden_states
856
+
857
+ class FeedForward(nn.Module):
858
+ r"""
859
+ A feed-forward layer.
860
+
861
+ Parameters:
862
+ dim (`int`): The number of channels in the input.
863
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
864
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
865
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
866
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
867
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
868
+ """
869
+
870
+ def __init__(
871
+ self,
872
+ dim: int,
873
+ dim_out: Optional[int] = None,
874
+ mult: int = 4,
875
+ dropout: float = 0.0,
876
+ activation_fn: str = "geglu",
877
+ final_dropout: bool = False,
878
+ ):
879
+ super().__init__()
880
+ inner_dim = int(dim * mult)
881
+ dim_out = dim_out if dim_out is not None else dim
882
+ linear_cls = nn.Linear
883
+
884
+ if activation_fn == "gelu":
885
+ act_fn = GELU(dim, inner_dim)
886
+ if activation_fn == "gelu-approximate":
887
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
888
+ elif activation_fn == "geglu":
889
+ act_fn = GEGLU(dim, inner_dim)
890
+ elif activation_fn == "geglu-approximate":
891
+ act_fn = ApproximateGELU(dim, inner_dim)
892
+
893
+ self.net = nn.ModuleList([])
894
+ # project in
895
+ self.net.append(act_fn)
896
+ # project dropout
897
+ self.net.append(nn.Dropout(dropout))
898
+ # project out
899
+ self.net.append(linear_cls(inner_dim, dim_out))
900
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
901
+ if final_dropout:
902
+ self.net.append(nn.Dropout(dropout))
903
+
904
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
905
+ for module in self.net:
906
+ hidden_states = module(hidden_states)
907
+ return hidden_states
908
+
909
+
910
+ @maybe_allow_in_graph
911
+ class BasicTransformerBlock(nn.Module):
912
+ r"""
913
+ A basic Transformer block.
914
+
915
+ Parameters:
916
+ dim (`int`): The number of channels in the input and output.
917
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
918
+ attention_head_dim (`int`): The number of channels in each head.
919
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
920
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
921
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
922
+ num_embeds_ada_norm (:
923
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
924
+ attention_bias (:
925
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
926
+ only_cross_attention (`bool`, *optional*):
927
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
928
+ double_self_attention (`bool`, *optional*):
929
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
930
+ upcast_attention (`bool`, *optional*):
931
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
932
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
933
+ Whether to use learnable elementwise affine parameters for normalization.
934
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
935
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
936
+ final_dropout (`bool` *optional*, defaults to False):
937
+ Whether to apply a final dropout after the last feed-forward layer.
938
+ positional_embeddings (`str`, *optional*, defaults to `None`):
939
+ The type of positional embeddings to apply to.
940
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
941
+ The maximum number of positional embeddings to apply.
942
+ """
943
+
944
+ def __init__(
945
+ self,
946
+ dim: int,
947
+ num_attention_heads: int,
948
+ attention_head_dim: int,
949
+ dropout=0.0,
950
+ cross_attention_dim: Optional[int] = None,
951
+ activation_fn: str = "geglu",
952
+ num_embeds_ada_norm: Optional[int] = None,
953
+ attention_bias: bool = False,
954
+ only_cross_attention: bool = False,
955
+ double_self_attention: bool = False,
956
+ upcast_attention: bool = False,
957
+ norm_elementwise_affine: bool = True,
958
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
959
+ norm_eps: float = 1e-5,
960
+ final_dropout: bool = False,
961
+ positional_embeddings: Optional[str] = None,
962
+ num_positional_embeddings: Optional[int] = None,
963
+ sa_attention_mode: str = "flash",
964
+ ca_attention_mode: str = "xformers",
965
+ use_rope: bool = False,
966
+ interpolation_scale_thw: Tuple[int] = (1, 1, 1),
967
+ block_idx: Optional[int] = None,
968
+ ):
969
+ super().__init__()
970
+ self.only_cross_attention = only_cross_attention
971
+
972
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
973
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
974
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
975
+ self.use_layer_norm = norm_type == "layer_norm"
976
+
977
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
978
+ raise ValueError(
979
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
980
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
981
+ )
982
+
983
+ if positional_embeddings and (num_positional_embeddings is None):
984
+ raise ValueError(
985
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
986
+ )
987
+
988
+ if positional_embeddings == "sinusoidal":
989
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
990
+ else:
991
+ self.pos_embed = None
992
+
993
+ # Define 3 blocks. Each block has its own normalization layer.
994
+ # 1. Self-Attn
995
+ if self.use_ada_layer_norm:
996
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
997
+ elif self.use_ada_layer_norm_zero:
998
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
999
+ else:
1000
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1001
+
1002
+ self.attn1 = Attention(
1003
+ query_dim=dim,
1004
+ heads=num_attention_heads,
1005
+ dim_head=attention_head_dim,
1006
+ dropout=dropout,
1007
+ bias=attention_bias,
1008
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1009
+ upcast_attention=upcast_attention,
1010
+ attention_mode=sa_attention_mode,
1011
+ use_rope=use_rope,
1012
+ interpolation_scale_thw=interpolation_scale_thw,
1013
+ )
1014
+
1015
+ # 2. Cross-Attn
1016
+ if cross_attention_dim is not None or double_self_attention:
1017
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
1018
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
1019
+ # the second cross attention block.
1020
+ self.norm2 = (
1021
+ AdaLayerNorm(dim, num_embeds_ada_norm)
1022
+ if self.use_ada_layer_norm
1023
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1024
+ )
1025
+ self.attn2 = Attention(
1026
+ query_dim=dim,
1027
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1028
+ heads=num_attention_heads,
1029
+ dim_head=attention_head_dim,
1030
+ dropout=dropout,
1031
+ bias=attention_bias,
1032
+ upcast_attention=upcast_attention,
1033
+ attention_mode=ca_attention_mode, # only xformers support attention_mask
1034
+ use_rope=False, # do not position in cross attention
1035
+ interpolation_scale_thw=interpolation_scale_thw,
1036
+ ) # is self-attn if encoder_hidden_states is none
1037
+ else:
1038
+ self.norm2 = None
1039
+ self.attn2 = None
1040
+
1041
+ # 3. Feed-forward
1042
+
1043
+ if not self.use_ada_layer_norm_single:
1044
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1045
+
1046
+ self.ff = FeedForward(
1047
+ dim,
1048
+ dropout=dropout,
1049
+ activation_fn=activation_fn,
1050
+ final_dropout=final_dropout,
1051
+ )
1052
+
1053
+ # 5. Scale-shift for PixArt-Alpha.
1054
+ if self.use_ada_layer_norm_single:
1055
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
1056
+
1057
+
1058
+ def forward(
1059
+ self,
1060
+ hidden_states: torch.FloatTensor,
1061
+ attention_mask: Optional[torch.FloatTensor] = None,
1062
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1063
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1064
+ timestep: Optional[torch.LongTensor] = None,
1065
+ cross_attention_kwargs: Dict[str, Any] = None,
1066
+ class_labels: Optional[torch.LongTensor] = None,
1067
+ frame: int = None,
1068
+ height: int = None,
1069
+ width: int = None,
1070
+ ) -> torch.FloatTensor:
1071
+ # Notice that normalization is always applied before the real computation in the following blocks.
1072
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1073
+
1074
+ # 0. Self-Attention
1075
+ batch_size = hidden_states.shape[0]
1076
+
1077
+ if self.use_ada_layer_norm:
1078
+ norm_hidden_states = self.norm1(hidden_states, timestep)
1079
+ elif self.use_ada_layer_norm_zero:
1080
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
1081
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
1082
+ )
1083
+ elif self.use_layer_norm:
1084
+ norm_hidden_states = self.norm1(hidden_states)
1085
+ elif self.use_ada_layer_norm_single:
1086
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
1087
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
1088
+ ).chunk(6, dim=1)
1089
+ norm_hidden_states = self.norm1(hidden_states)
1090
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
1091
+ norm_hidden_states = norm_hidden_states.squeeze(1)
1092
+ else:
1093
+ raise ValueError("Incorrect norm used")
1094
+
1095
+ if self.pos_embed is not None:
1096
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1097
+
1098
+ attn_output = self.attn1(
1099
+ norm_hidden_states,
1100
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1101
+ attention_mask=attention_mask,
1102
+ frame=frame,
1103
+ height=height,
1104
+ width=width,
1105
+ **cross_attention_kwargs,
1106
+ )
1107
+ if self.use_ada_layer_norm_zero:
1108
+ attn_output = gate_msa.unsqueeze(1) * attn_output
1109
+ elif self.use_ada_layer_norm_single:
1110
+ attn_output = gate_msa * attn_output
1111
+
1112
+ hidden_states = attn_output + hidden_states
1113
+ if hidden_states.ndim == 4:
1114
+ hidden_states = hidden_states.squeeze(1)
1115
+
1116
+ # 1. Cross-Attention
1117
+ if self.attn2 is not None:
1118
+
1119
+ if self.use_ada_layer_norm:
1120
+ norm_hidden_states = self.norm2(hidden_states, timestep)
1121
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
1122
+ norm_hidden_states = self.norm2(hidden_states)
1123
+ elif self.use_ada_layer_norm_single:
1124
+ # For PixArt norm2 isn't applied here:
1125
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
1126
+ norm_hidden_states = hidden_states
1127
+ else:
1128
+ raise ValueError("Incorrect norm")
1129
+
1130
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
1131
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1132
+
1133
+ attn_output = self.attn2(
1134
+ norm_hidden_states,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ attention_mask=encoder_attention_mask,
1137
+ **cross_attention_kwargs,
1138
+ )
1139
+ hidden_states = attn_output + hidden_states
1140
+
1141
+
1142
+ # 2. Feed-forward
1143
+ if not self.use_ada_layer_norm_single:
1144
+ norm_hidden_states = self.norm3(hidden_states)
1145
+
1146
+ if self.use_ada_layer_norm_zero:
1147
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1148
+
1149
+ if self.use_ada_layer_norm_single:
1150
+ norm_hidden_states = self.norm2(hidden_states)
1151
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1152
+
1153
+ ff_output = self.ff(norm_hidden_states)
1154
+
1155
+ if self.use_ada_layer_norm_zero:
1156
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
1157
+ elif self.use_ada_layer_norm_single:
1158
+ ff_output = gate_mlp * ff_output
1159
+
1160
+
1161
+ hidden_states = ff_output + hidden_states
1162
+ if hidden_states.ndim == 4:
1163
+ hidden_states = hidden_states.squeeze(1)
1164
+
1165
+ return hidden_states
1166
+
1167
+
1168
+ class AdaLayerNormSingle(nn.Module):
1169
+ r"""
1170
+ Norm layer adaptive layer norm single (adaLN-single).
1171
+
1172
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
1173
+
1174
+ Parameters:
1175
+ embedding_dim (`int`): The size of each embedding vector.
1176
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
1177
+ """
1178
+
1179
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
1180
+ super().__init__()
1181
+
1182
+ self.emb = CombinedTimestepSizeEmbeddings(
1183
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
1184
+ )
1185
+
1186
+ self.silu = nn.SiLU()
1187
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
1188
+
1189
+ def forward(
1190
+ self,
1191
+ timestep: torch.Tensor,
1192
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
1193
+ batch_size: int = None,
1194
+ hidden_dtype: Optional[torch.dtype] = None,
1195
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1196
+ # No modulation happening here.
1197
+ embedded_timestep = self.emb(
1198
+ timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
1199
+ )
1200
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep