Pusheen commited on
Commit
af37dce
1 Parent(s): 554a086

Upload 22 files

Browse files
DejaVuSansMono.ttf ADDED
Binary file (341 kB). View file
 
conf/unet/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "flip_sin_to_cos": true,
22
+ "freq_shift": 0,
23
+ "in_channels": 4,
24
+ "layers_per_block": 2,
25
+ "mid_block_scale_factor": 1,
26
+ "norm_eps": 1e-05,
27
+ "norm_num_groups": 32,
28
+ "out_channels": 4,
29
+ "sample_size": 64,
30
+ "up_block_types": [
31
+ "UpBlock2D",
32
+ "CrossAttnUpBlock2D",
33
+ "CrossAttnUpBlock2D",
34
+ "CrossAttnUpBlock2D"
35
+ ]
36
+ }
images/hello_kitty_results.png ADDED
images/input.png ADDED
my_model/__init__.py ADDED
File without changes
my_model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
my_model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (144 Bytes). View file
 
my_model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (133 Bytes). View file
 
my_model/__pycache__/attention.cpython-310.pyc ADDED
Binary file (22.5 kB). View file
 
my_model/__pycache__/attention.cpython-38.pyc ADDED
Binary file (22.4 kB). View file
 
my_model/__pycache__/attention.cpython-39.pyc ADDED
Binary file (22.3 kB). View file
 
my_model/__pycache__/unet_2d_blocks.cpython-310.pyc ADDED
Binary file (27.3 kB). View file
 
my_model/__pycache__/unet_2d_blocks.cpython-38.pyc ADDED
Binary file (26.7 kB). View file
 
my_model/__pycache__/unet_2d_blocks.cpython-39.pyc ADDED
Binary file (26.6 kB). View file
 
my_model/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
my_model/__pycache__/unet_2d_condition.cpython-38.pyc ADDED
Binary file (10.2 kB). View file
 
my_model/__pycache__/unet_2d_condition.cpython-39.pyc ADDED
Binary file (10.2 kB). View file
 
my_model/attention.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.modeling_utils import ModelMixin
24
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
25
+ from diffusers.utils import BaseOutput
26
+ from diffusers.utils.import_utils import is_xformers_available
27
+
28
+ @dataclass
29
+ class Transformer2DModelOutput(BaseOutput):
30
+ """
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
33
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
34
+ for the unnoised latent pixels.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ if is_xformers_available():
41
+ import xformers
42
+ import xformers.ops
43
+ else:
44
+ xformers = None
45
+
46
+
47
+ class Transformer2DModel(ModelMixin, ConfigMixin):
48
+ """
49
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
50
+ embeddings) inputs_coarse.
51
+
52
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
53
+ transformer action. Finally, reshape to image.
54
+
55
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
56
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
57
+ classes of unnoised image.
58
+
59
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
60
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
61
+
62
+ Parameters:
63
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
+ in_channels (`int`, *optional*):
66
+ Pass if the input is continuous. The number of channels in the input and output.
67
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
70
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
71
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
72
+ `ImagePositionalEmbeddings`.
73
+ num_vector_embeds (`int`, *optional*):
74
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
75
+ Includes the class for the masked latent pixel.
76
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
77
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
78
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
79
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
80
+ up to but not more than steps than `num_embeds_ada_norm`.
81
+ attention_bias (`bool`, *optional*):
82
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
83
+ """
84
+
85
+ @register_to_config
86
+ def __init__(
87
+ self,
88
+ num_attention_heads: int = 16,
89
+ attention_head_dim: int = 88,
90
+ in_channels: Optional[int] = None,
91
+ num_layers: int = 1,
92
+ dropout: float = 0.0,
93
+ norm_num_groups: int = 32,
94
+ cross_attention_dim: Optional[int] = None,
95
+ attention_bias: bool = False,
96
+ sample_size: Optional[int] = None,
97
+ num_vector_embeds: Optional[int] = None,
98
+ activation_fn: str = "geglu",
99
+ num_embeds_ada_norm: Optional[int] = None,
100
+ ):
101
+ super().__init__()
102
+ self.num_attention_heads = num_attention_heads
103
+ self.attention_head_dim = attention_head_dim
104
+ inner_dim = num_attention_heads * attention_head_dim
105
+
106
+ # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
107
+ # Define whether input is continuous or discrete depending on configuration
108
+ self.is_input_continuous = in_channels is not None
109
+ self.is_input_vectorized = num_vector_embeds is not None
110
+
111
+ if self.is_input_continuous and self.is_input_vectorized:
112
+ raise ValueError(
113
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
114
+ " sure that either `in_channels` or `num_vector_embeds` is None."
115
+ )
116
+ elif not self.is_input_continuous and not self.is_input_vectorized:
117
+ raise ValueError(
118
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
119
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
120
+ )
121
+
122
+ # 2. Define input layers
123
+ if self.is_input_continuous:
124
+ self.in_channels = in_channels
125
+
126
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
127
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
128
+ elif self.is_input_vectorized:
129
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
130
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
131
+
132
+ self.height = sample_size
133
+ self.width = sample_size
134
+ self.num_vector_embeds = num_vector_embeds
135
+ self.num_latent_pixels = self.height * self.width
136
+
137
+ self.latent_image_embedding = ImagePositionalEmbeddings(
138
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
139
+ )
140
+
141
+ # 3. Define transformers blocks
142
+ self.transformer_blocks = nn.ModuleList(
143
+ [
144
+ BasicTransformerBlock(
145
+ inner_dim,
146
+ num_attention_heads,
147
+ attention_head_dim,
148
+ dropout=dropout,
149
+ cross_attention_dim=cross_attention_dim,
150
+ activation_fn=activation_fn,
151
+ num_embeds_ada_norm=num_embeds_ada_norm,
152
+ attention_bias=attention_bias,
153
+ )
154
+ for d in range(num_layers)
155
+ ]
156
+ )
157
+
158
+ # 4. Define output layers
159
+ if self.is_input_continuous:
160
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
161
+ elif self.is_input_vectorized:
162
+ self.norm_out = nn.LayerNorm(inner_dim)
163
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
164
+
165
+ def _set_attention_slice(self, slice_size):
166
+ for block in self.transformer_blocks:
167
+ block._set_attention_slice(slice_size)
168
+
169
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attn_map=None, attn_shift=False, obj_ids=None, relationship=None, return_dict: bool = True):
170
+ """
171
+ Args:
172
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
173
+ When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
174
+ hidden_states
175
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
176
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
177
+ self-attention.
178
+ timestep ( `torch.long`, *optional*):
179
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
180
+ return_dict (`bool`, *optional*, defaults to `True`):
181
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
182
+
183
+ Returns:
184
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
185
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
186
+ tensor.
187
+ """
188
+ # 1. Input
189
+ if self.is_input_continuous:
190
+ batch, channel, height, weight = hidden_states.shape
191
+ residual = hidden_states
192
+ hidden_states = self.norm(hidden_states)
193
+ hidden_states = self.proj_in(hidden_states)
194
+ inner_dim = hidden_states.shape[1]
195
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
196
+ elif self.is_input_vectorized:
197
+ hidden_states = self.latent_image_embedding(hidden_states)
198
+
199
+ # 2. Blocks
200
+ for block in self.transformer_blocks:
201
+ hidden_states, cross_attn_prob = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
202
+
203
+ # 3. Output
204
+ if self.is_input_continuous:
205
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
206
+ hidden_states = self.proj_out(hidden_states)
207
+ output = hidden_states + residual
208
+ elif self.is_input_vectorized:
209
+ hidden_states = self.norm_out(hidden_states)
210
+ logits = self.out(hidden_states)
211
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
212
+ logits = logits.permute(0, 2, 1)
213
+
214
+ # log(p(x_0))
215
+ output = F.log_softmax(logits.double(), dim=1).float()
216
+
217
+ if not return_dict:
218
+ return (output,)
219
+
220
+ return Transformer2DModelOutput(sample=output), cross_attn_prob
221
+
222
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
223
+ for block in self.transformer_blocks:
224
+ block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
225
+
226
+
227
+ class AttentionBlock(nn.Module):
228
+ """
229
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
230
+ to the N-d case.
231
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
232
+ Uses three q, k, v linear layers to compute attention.
233
+
234
+ Parameters:
235
+ channels (`int`): The number of channels in the input and output.
236
+ num_head_channels (`int`, *optional*):
237
+ The number of channels in each head. If None, then `num_heads` = 1.
238
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
239
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
240
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ channels: int,
246
+ num_head_channels: Optional[int] = None,
247
+ norm_num_groups: int = 32,
248
+ rescale_output_factor: float = 1.0,
249
+ eps: float = 1e-5,
250
+ ):
251
+ super().__init__()
252
+ self.channels = channels
253
+
254
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
255
+ self.num_head_size = num_head_channels
256
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
257
+
258
+ # define q,k,v as linear layers
259
+ self.query = nn.Linear(channels, channels)
260
+ self.key = nn.Linear(channels, channels)
261
+ self.value = nn.Linear(channels, channels)
262
+
263
+ self.rescale_output_factor = rescale_output_factor
264
+ self.proj_attn = nn.Linear(channels, channels, 1)
265
+
266
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
267
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
268
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
269
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
270
+ return new_projection
271
+
272
+ def forward(self, hidden_states):
273
+ residual = hidden_states
274
+ batch, channel, height, width = hidden_states.shape
275
+
276
+ # norm
277
+ hidden_states = self.group_norm(hidden_states)
278
+
279
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
280
+
281
+ # proj to q, k, v
282
+ query_proj = self.query(hidden_states)
283
+ key_proj = self.key(hidden_states)
284
+ value_proj = self.value(hidden_states)
285
+
286
+ # transpose
287
+ query_states = self.transpose_for_scores(query_proj)
288
+ key_states = self.transpose_for_scores(key_proj)
289
+ value_states = self.transpose_for_scores(value_proj)
290
+
291
+ # get scores
292
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
293
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
294
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
295
+
296
+ # compute attention output
297
+ hidden_states = torch.matmul(attention_probs, value_states)
298
+
299
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
300
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
301
+ hidden_states = hidden_states.view(new_hidden_states_shape)
302
+
303
+ # compute next hidden_states
304
+ hidden_states = self.proj_attn(hidden_states)
305
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
306
+
307
+ # res connect and rescale
308
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
309
+ return hidden_states
310
+
311
+
312
+ class BasicTransformerBlock(nn.Module):
313
+ r"""
314
+ A basic Transformer block.
315
+
316
+ Parameters:
317
+ dim (`int`): The number of channels in the input and output.
318
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
319
+ attention_head_dim (`int`): The number of channels in each head.
320
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
321
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
322
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
323
+ num_embeds_ada_norm (:
324
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
325
+ attention_bias (:
326
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
327
+ """
328
+
329
+ def __init__(
330
+ self,
331
+ dim: int,
332
+ num_attention_heads: int,
333
+ attention_head_dim: int,
334
+ dropout=0.0,
335
+ cross_attention_dim: Optional[int] = None,
336
+ activation_fn: str = "geglu",
337
+ num_embeds_ada_norm: Optional[int] = None,
338
+ attention_bias: bool = False,
339
+ ):
340
+ super().__init__()
341
+ self.attn1 = CrossAttention(
342
+ query_dim=dim,
343
+ heads=num_attention_heads,
344
+ dim_head=attention_head_dim,
345
+ dropout=dropout,
346
+ bias=attention_bias,
347
+ ) # is a self-attention
348
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
349
+ self.attn2 = CrossAttention(
350
+ query_dim=dim,
351
+ cross_attention_dim=cross_attention_dim,
352
+ heads=num_attention_heads,
353
+ dim_head=attention_head_dim,
354
+ dropout=dropout,
355
+ bias=attention_bias,
356
+ ) # is self-attn if context is none
357
+
358
+ # layer norms
359
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
360
+ if self.use_ada_layer_norm:
361
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
362
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
363
+ else:
364
+ self.norm1 = nn.LayerNorm(dim)
365
+ self.norm2 = nn.LayerNorm(dim)
366
+ self.norm3 = nn.LayerNorm(dim)
367
+
368
+ def _set_attention_slice(self, slice_size):
369
+ self.attn1._slice_size = slice_size
370
+ self.attn2._slice_size = slice_size
371
+
372
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
373
+ if not is_xformers_available():
374
+ print("Here is how to install it")
375
+ raise ModuleNotFoundError(
376
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
377
+ " xformers",
378
+ name="xformers",
379
+ )
380
+ elif not torch.cuda.is_available():
381
+ raise ValueError(
382
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
383
+ " available for GPU "
384
+ )
385
+ else:
386
+ try:
387
+ # Make sure we can run the memory efficient attention
388
+ _ = xformers.ops.memory_efficient_attention(
389
+ torch.randn((1, 2, 40), device="cuda"),
390
+ torch.randn((1, 2, 40), device="cuda"),
391
+ torch.randn((1, 2, 40), device="cuda"),
392
+ )
393
+ except Exception as e:
394
+ raise e
395
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
396
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
397
+
398
+ def forward(self, hidden_states, context=None, timestep=None):
399
+ # 1. Self-Attention
400
+ norm_hidden_states = (
401
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
402
+ )
403
+ tmp_hidden_states, cross_attn_prob = self.attn1(norm_hidden_states)
404
+ hidden_states = tmp_hidden_states + hidden_states
405
+
406
+ # 2. Cross-Attention
407
+ norm_hidden_states = (
408
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
409
+ )
410
+ tmp_hidden_states, cross_attn_prob = self.attn2(norm_hidden_states, context=context)
411
+ hidden_states = tmp_hidden_states + hidden_states
412
+
413
+ # 3. Feed-forward
414
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
415
+
416
+ return hidden_states, cross_attn_prob
417
+
418
+
419
+ class CrossAttention(nn.Module):
420
+ r"""
421
+ A cross attention layer.
422
+
423
+ Parameters:
424
+ query_dim (`int`): The number of channels in the query.
425
+ cross_attention_dim (`int`, *optional*):
426
+ The number of channels in the context. If not given, defaults to `query_dim`.
427
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
428
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
429
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
430
+ bias (`bool`, *optional*, defaults to False):
431
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
432
+ """
433
+
434
+ def __init__(
435
+ self,
436
+ query_dim: int,
437
+ cross_attention_dim: Optional[int] = None,
438
+ heads: int = 8,
439
+ dim_head: int = 64,
440
+ dropout: float = 0.0,
441
+ bias=False,
442
+ ):
443
+ super().__init__()
444
+ inner_dim = dim_head * heads
445
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
446
+
447
+ self.scale = dim_head**-0.5
448
+ self.heads = heads
449
+ # for slice_size > 0 the attention score computation
450
+ # is split across the batch axis to save memory
451
+ # You can set slice_size with `set_attention_slice`
452
+ self._slice_size = None
453
+ self._use_memory_efficient_attention_xformers = False
454
+
455
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
456
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
457
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
458
+
459
+ self.to_out = nn.ModuleList([])
460
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
461
+ self.to_out.append(nn.Dropout(dropout))
462
+
463
+ def reshape_heads_to_batch_dim(self, tensor):
464
+ batch_size, seq_len, dim = tensor.shape
465
+ head_size = self.heads
466
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
467
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
468
+ return tensor
469
+
470
+ def reshape_batch_dim_to_heads(self, tensor):
471
+ batch_size, seq_len, dim = tensor.shape
472
+ head_size = self.heads
473
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
474
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
475
+ return tensor
476
+
477
+ def forward(self, hidden_states, context=None, mask=None):
478
+ batch_size, sequence_length, _ = hidden_states.shape
479
+
480
+ query = self.to_q(hidden_states)
481
+ context = context if context is not None else hidden_states
482
+ key = self.to_k(context)
483
+ value = self.to_v(context)
484
+
485
+ dim = query.shape[-1]
486
+
487
+ query = self.reshape_heads_to_batch_dim(query)
488
+ key = self.reshape_heads_to_batch_dim(key)
489
+ value = self.reshape_heads_to_batch_dim(value)
490
+
491
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
492
+
493
+ # attention, what we cannot get enough of
494
+ if self._use_memory_efficient_attention_xformers:
495
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value)
496
+ else:
497
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
498
+ hidden_states, attention_probs = self._attention(query, key, value)
499
+ else:
500
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
501
+
502
+ # linear proj
503
+ hidden_states = self.to_out[0](hidden_states)
504
+ # dropout
505
+ hidden_states = self.to_out[1](hidden_states)
506
+ return hidden_states, attention_probs
507
+
508
+ def _attention(self, query, key, value):
509
+ # TODO: use baddbmm for better performance
510
+ if query.device.type == "mps":
511
+ # Better performance on mps (~20-25%)
512
+ attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
513
+ else:
514
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
515
+ attention_probs = attention_scores.softmax(dim=-1)
516
+ # compute attention output
517
+
518
+ if query.device.type == "mps":
519
+ hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
520
+ else:
521
+ hidden_states = torch.matmul(attention_probs, value)
522
+
523
+ # reshape hidden_states
524
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
525
+ return hidden_states, attention_probs
526
+
527
+ def _sliced_attention(self, query, key, value, sequence_length, dim):
528
+ batch_size_attention = query.shape[0]
529
+ hidden_states = torch.zeros(
530
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
531
+ )
532
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
533
+ for i in range(hidden_states.shape[0] // slice_size):
534
+ start_idx = i * slice_size
535
+ end_idx = (i + 1) * slice_size
536
+ if query.device.type == "mps":
537
+ # Better performance on mps (~20-25%)
538
+ attn_slice = (
539
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
540
+ * self.scale
541
+ )
542
+ else:
543
+ attn_slice = (
544
+ torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
545
+ ) # TODO: use baddbmm for better performance
546
+ attn_slice = attn_slice.softmax(dim=-1)
547
+ if query.device.type == "mps":
548
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
549
+ else:
550
+ attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
551
+
552
+ hidden_states[start_idx:end_idx] = attn_slice
553
+
554
+ # reshape hidden_states
555
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
556
+ return hidden_states
557
+
558
+ def _memory_efficient_attention_xformers(self, query, key, value):
559
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
560
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
561
+ return hidden_states
562
+
563
+
564
+ class FeedForward(nn.Module):
565
+ r"""
566
+ A feed-forward layer.
567
+
568
+ Parameters:
569
+ dim (`int`): The number of channels in the input.
570
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
571
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
572
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
573
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
574
+ """
575
+
576
+ def __init__(
577
+ self,
578
+ dim: int,
579
+ dim_out: Optional[int] = None,
580
+ mult: int = 4,
581
+ dropout: float = 0.0,
582
+ activation_fn: str = "geglu",
583
+ ):
584
+ super().__init__()
585
+ inner_dim = int(dim * mult)
586
+ dim_out = dim_out if dim_out is not None else dim
587
+
588
+ if activation_fn == "geglu":
589
+ geglu = GEGLU(dim, inner_dim)
590
+ elif activation_fn == "geglu-approximate":
591
+ geglu = ApproximateGELU(dim, inner_dim)
592
+
593
+ self.net = nn.ModuleList([])
594
+ # project in
595
+ self.net.append(geglu)
596
+ # project dropout
597
+ self.net.append(nn.Dropout(dropout))
598
+ # project out
599
+ self.net.append(nn.Linear(inner_dim, dim_out))
600
+
601
+ def forward(self, hidden_states):
602
+ for module in self.net:
603
+ hidden_states = module(hidden_states)
604
+ return hidden_states
605
+
606
+
607
+ # feedforward
608
+ class GEGLU(nn.Module):
609
+ r"""
610
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
611
+
612
+ Parameters:
613
+ dim_in (`int`): The number of channels in the input.
614
+ dim_out (`int`): The number of channels in the output.
615
+ """
616
+
617
+ def __init__(self, dim_in: int, dim_out: int):
618
+ super().__init__()
619
+ self.proj = nn.Linear(dim_in, dim_out * 2)
620
+
621
+ def gelu(self, gate):
622
+ if gate.device.type != "mps":
623
+ return F.gelu(gate)
624
+ # mps: gelu is not implemented for float16
625
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
626
+
627
+ def forward(self, hidden_states):
628
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
629
+ return hidden_states * self.gelu(gate)
630
+
631
+
632
+ class ApproximateGELU(nn.Module):
633
+ """
634
+ The approximate form of Gaussian Error Linear Unit (GELU)
635
+
636
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
637
+ """
638
+
639
+ def __init__(self, dim_in: int, dim_out: int):
640
+ super().__init__()
641
+ self.proj = nn.Linear(dim_in, dim_out)
642
+
643
+ def forward(self, x):
644
+ x = self.proj(x)
645
+ return x * torch.sigmoid(1.702 * x)
646
+
647
+
648
+ class AdaLayerNorm(nn.Module):
649
+ """
650
+ Norm layer modified to incorporate timestep embeddings.
651
+ """
652
+
653
+ def __init__(self, embedding_dim, num_embeddings):
654
+ super().__init__()
655
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
656
+ self.silu = nn.SiLU()
657
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
658
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
659
+
660
+ def forward(self, x, timestep):
661
+ emb = self.linear(self.silu(self.emb(timestep)))
662
+ scale, shift = torch.chunk(emb, 2)
663
+ x = self.norm(x) * (1 + scale) + shift
664
+ return x
my_model/unet_2d_blocks.py ADDED
@@ -0,0 +1,1602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .attention import AttentionBlock, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
20
+
21
+
22
+ def get_down_block(
23
+ down_block_type,
24
+ num_layers,
25
+ in_channels,
26
+ out_channels,
27
+ temb_channels,
28
+ add_downsample,
29
+ resnet_eps,
30
+ resnet_act_fn,
31
+ attn_num_head_channels,
32
+ resnet_groups=None,
33
+ cross_attention_dim=None,
34
+ downsample_padding=None,
35
+ ):
36
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
37
+ if down_block_type == "DownBlock2D":
38
+ return DownBlock2D(
39
+ num_layers=num_layers,
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ temb_channels=temb_channels,
43
+ add_downsample=add_downsample,
44
+ resnet_eps=resnet_eps,
45
+ resnet_act_fn=resnet_act_fn,
46
+ resnet_groups=resnet_groups,
47
+ downsample_padding=downsample_padding,
48
+ )
49
+ elif down_block_type == "AttnDownBlock2D":
50
+ return AttnDownBlock2D(
51
+ num_layers=num_layers,
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ temb_channels=temb_channels,
55
+ add_downsample=add_downsample,
56
+ resnet_eps=resnet_eps,
57
+ resnet_act_fn=resnet_act_fn,
58
+ resnet_groups=resnet_groups,
59
+ downsample_padding=downsample_padding,
60
+ attn_num_head_channels=attn_num_head_channels,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock2D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
65
+ return CrossAttnDownBlock2D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ resnet_groups=resnet_groups,
74
+ downsample_padding=downsample_padding,
75
+ cross_attention_dim=cross_attention_dim,
76
+ attn_num_head_channels=attn_num_head_channels,
77
+ )
78
+ elif down_block_type == "SkipDownBlock2D":
79
+ return SkipDownBlock2D(
80
+ num_layers=num_layers,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ downsample_padding=downsample_padding,
88
+ )
89
+ elif down_block_type == "AttnSkipDownBlock2D":
90
+ return AttnSkipDownBlock2D(
91
+ num_layers=num_layers,
92
+ in_channels=in_channels,
93
+ out_channels=out_channels,
94
+ temb_channels=temb_channels,
95
+ add_downsample=add_downsample,
96
+ resnet_eps=resnet_eps,
97
+ resnet_act_fn=resnet_act_fn,
98
+ downsample_padding=downsample_padding,
99
+ attn_num_head_channels=attn_num_head_channels,
100
+ )
101
+ elif down_block_type == "DownEncoderBlock2D":
102
+ return DownEncoderBlock2D(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ add_downsample=add_downsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ downsample_padding=downsample_padding,
111
+ )
112
+ elif down_block_type == "AttnDownEncoderBlock2D":
113
+ return AttnDownEncoderBlock2D(
114
+ num_layers=num_layers,
115
+ in_channels=in_channels,
116
+ out_channels=out_channels,
117
+ add_downsample=add_downsample,
118
+ resnet_eps=resnet_eps,
119
+ resnet_act_fn=resnet_act_fn,
120
+ resnet_groups=resnet_groups,
121
+ downsample_padding=downsample_padding,
122
+ attn_num_head_channels=attn_num_head_channels,
123
+ )
124
+ raise ValueError(f"{down_block_type} does not exist.")
125
+
126
+
127
+ def get_up_block(
128
+ up_block_type,
129
+ num_layers,
130
+ in_channels,
131
+ out_channels,
132
+ prev_output_channel,
133
+ temb_channels,
134
+ add_upsample,
135
+ resnet_eps,
136
+ resnet_act_fn,
137
+ attn_num_head_channels,
138
+ resnet_groups=None,
139
+ cross_attention_dim=None,
140
+ ):
141
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
142
+ if up_block_type == "UpBlock2D":
143
+ return UpBlock2D(
144
+ num_layers=num_layers,
145
+ in_channels=in_channels,
146
+ out_channels=out_channels,
147
+ prev_output_channel=prev_output_channel,
148
+ temb_channels=temb_channels,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ )
154
+ elif up_block_type == "CrossAttnUpBlock2D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
157
+ return CrossAttnUpBlock2D(
158
+ num_layers=num_layers,
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ prev_output_channel=prev_output_channel,
162
+ temb_channels=temb_channels,
163
+ add_upsample=add_upsample,
164
+ resnet_eps=resnet_eps,
165
+ resnet_act_fn=resnet_act_fn,
166
+ resnet_groups=resnet_groups,
167
+ cross_attention_dim=cross_attention_dim,
168
+ attn_num_head_channels=attn_num_head_channels,
169
+ )
170
+ elif up_block_type == "AttnUpBlock2D":
171
+ return AttnUpBlock2D(
172
+ num_layers=num_layers,
173
+ in_channels=in_channels,
174
+ out_channels=out_channels,
175
+ prev_output_channel=prev_output_channel,
176
+ temb_channels=temb_channels,
177
+ add_upsample=add_upsample,
178
+ resnet_eps=resnet_eps,
179
+ resnet_act_fn=resnet_act_fn,
180
+ resnet_groups=resnet_groups,
181
+ attn_num_head_channels=attn_num_head_channels,
182
+ )
183
+ elif up_block_type == "SkipUpBlock2D":
184
+ return SkipUpBlock2D(
185
+ num_layers=num_layers,
186
+ in_channels=in_channels,
187
+ out_channels=out_channels,
188
+ prev_output_channel=prev_output_channel,
189
+ temb_channels=temb_channels,
190
+ add_upsample=add_upsample,
191
+ resnet_eps=resnet_eps,
192
+ resnet_act_fn=resnet_act_fn,
193
+ )
194
+ elif up_block_type == "AttnSkipUpBlock2D":
195
+ return AttnSkipUpBlock2D(
196
+ num_layers=num_layers,
197
+ in_channels=in_channels,
198
+ out_channels=out_channels,
199
+ prev_output_channel=prev_output_channel,
200
+ temb_channels=temb_channels,
201
+ add_upsample=add_upsample,
202
+ resnet_eps=resnet_eps,
203
+ resnet_act_fn=resnet_act_fn,
204
+ attn_num_head_channels=attn_num_head_channels,
205
+ )
206
+ elif up_block_type == "UpDecoderBlock2D":
207
+ return UpDecoderBlock2D(
208
+ num_layers=num_layers,
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ add_upsample=add_upsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ resnet_groups=resnet_groups,
215
+ )
216
+ elif up_block_type == "AttnUpDecoderBlock2D":
217
+ return AttnUpDecoderBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ add_upsample=add_upsample,
222
+ resnet_eps=resnet_eps,
223
+ resnet_act_fn=resnet_act_fn,
224
+ resnet_groups=resnet_groups,
225
+ attn_num_head_channels=attn_num_head_channels,
226
+ )
227
+ raise ValueError(f"{up_block_type} does not exist.")
228
+
229
+
230
+ class UNetMidBlock2D(nn.Module):
231
+ def __init__(
232
+ self,
233
+ in_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default",
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ resnet_pre_norm: bool = True,
242
+ attn_num_head_channels=1,
243
+ attention_type="default",
244
+ output_scale_factor=1.0,
245
+ **kwargs,
246
+ ):
247
+ super().__init__()
248
+
249
+ self.attention_type = attention_type
250
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
251
+
252
+ # there is always at least one resnet
253
+ resnets = [
254
+ ResnetBlock2D(
255
+ in_channels=in_channels,
256
+ out_channels=in_channels,
257
+ temb_channels=temb_channels,
258
+ eps=resnet_eps,
259
+ groups=resnet_groups,
260
+ dropout=dropout,
261
+ time_embedding_norm=resnet_time_scale_shift,
262
+ non_linearity=resnet_act_fn,
263
+ output_scale_factor=output_scale_factor,
264
+ pre_norm=resnet_pre_norm,
265
+ )
266
+ ]
267
+ attentions = []
268
+
269
+ for _ in range(num_layers):
270
+ attentions.append(
271
+ AttentionBlock(
272
+ in_channels,
273
+ num_head_channels=attn_num_head_channels,
274
+ rescale_output_factor=output_scale_factor,
275
+ eps=resnet_eps,
276
+ norm_num_groups=resnet_groups,
277
+ )
278
+ )
279
+ resnets.append(
280
+ ResnetBlock2D(
281
+ in_channels=in_channels,
282
+ out_channels=in_channels,
283
+ temb_channels=temb_channels,
284
+ eps=resnet_eps,
285
+ groups=resnet_groups,
286
+ dropout=dropout,
287
+ time_embedding_norm=resnet_time_scale_shift,
288
+ non_linearity=resnet_act_fn,
289
+ output_scale_factor=output_scale_factor,
290
+ pre_norm=resnet_pre_norm,
291
+ )
292
+ )
293
+
294
+ self.attentions = nn.ModuleList(attentions)
295
+ self.resnets = nn.ModuleList(resnets)
296
+
297
+ def forward(self, hidden_states, temb=None, encoder_states=None):
298
+ hidden_states = self.resnets[0](hidden_states, temb)
299
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
300
+ if self.attention_type == "default":
301
+ hidden_states = attn(hidden_states)
302
+ else:
303
+ hidden_states = attn(hidden_states, encoder_states)
304
+ hidden_states = resnet(hidden_states, temb)
305
+
306
+ return hidden_states
307
+
308
+
309
+ class UNetMidBlock2DCrossAttn(nn.Module):
310
+ def __init__(
311
+ self,
312
+ in_channels: int,
313
+ temb_channels: int,
314
+ dropout: float = 0.0,
315
+ num_layers: int = 1,
316
+ resnet_eps: float = 1e-6,
317
+ resnet_time_scale_shift: str = "default",
318
+ resnet_act_fn: str = "swish",
319
+ resnet_groups: int = 32,
320
+ resnet_pre_norm: bool = True,
321
+ attn_num_head_channels=1,
322
+ attention_type="default",
323
+ output_scale_factor=1.0,
324
+ cross_attention_dim=1280,
325
+ **kwargs,
326
+ ):
327
+ super().__init__()
328
+
329
+ self.attention_type = attention_type
330
+ self.attn_num_head_channels = attn_num_head_channels
331
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
332
+
333
+ # there is always at least one resnet
334
+ resnets = [
335
+ ResnetBlock2D(
336
+ in_channels=in_channels,
337
+ out_channels=in_channels,
338
+ temb_channels=temb_channels,
339
+ eps=resnet_eps,
340
+ groups=resnet_groups,
341
+ dropout=dropout,
342
+ time_embedding_norm=resnet_time_scale_shift,
343
+ non_linearity=resnet_act_fn,
344
+ output_scale_factor=output_scale_factor,
345
+ pre_norm=resnet_pre_norm,
346
+ )
347
+ ]
348
+ attentions = []
349
+
350
+ for _ in range(num_layers):
351
+ attentions.append(
352
+ Transformer2DModel(
353
+ attn_num_head_channels,
354
+ in_channels // attn_num_head_channels,
355
+ in_channels=in_channels,
356
+ num_layers=1,
357
+ cross_attention_dim=cross_attention_dim,
358
+ norm_num_groups=resnet_groups,
359
+ )
360
+ )
361
+ resnets.append(
362
+ ResnetBlock2D(
363
+ in_channels=in_channels,
364
+ out_channels=in_channels,
365
+ temb_channels=temb_channels,
366
+ eps=resnet_eps,
367
+ groups=resnet_groups,
368
+ dropout=dropout,
369
+ time_embedding_norm=resnet_time_scale_shift,
370
+ non_linearity=resnet_act_fn,
371
+ output_scale_factor=output_scale_factor,
372
+ pre_norm=resnet_pre_norm,
373
+ )
374
+ )
375
+
376
+ self.attentions = nn.ModuleList(attentions)
377
+ self.resnets = nn.ModuleList(resnets)
378
+
379
+ def set_attention_slice(self, slice_size):
380
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
381
+ raise ValueError(
382
+ f"Make sure slice_size {slice_size} is a divisor of "
383
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
384
+ )
385
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
386
+ raise ValueError(
387
+ f"Chunk_size {slice_size} has to be smaller or equal to "
388
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
389
+ )
390
+
391
+ for attn in self.attentions:
392
+ attn._set_attention_slice(slice_size)
393
+
394
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
395
+ for attn in self.attentions:
396
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
397
+
398
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
399
+ hidden_states = self.resnets[0](hidden_states, temb)
400
+ mid_attn = []
401
+ for layer_idx, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
402
+ hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states)
403
+ hidden_states = hidden_states.sample
404
+ hidden_states = resnet(hidden_states, temb)
405
+ mid_attn.append(cross_attn_prob)
406
+ return hidden_states, mid_attn
407
+
408
+
409
+ class AttnDownBlock2D(nn.Module):
410
+ def __init__(
411
+ self,
412
+ in_channels: int,
413
+ out_channels: int,
414
+ temb_channels: int,
415
+ dropout: float = 0.0,
416
+ num_layers: int = 1,
417
+ resnet_eps: float = 1e-6,
418
+ resnet_time_scale_shift: str = "default",
419
+ resnet_act_fn: str = "swish",
420
+ resnet_groups: int = 32,
421
+ resnet_pre_norm: bool = True,
422
+ attn_num_head_channels=1,
423
+ attention_type="default",
424
+ output_scale_factor=1.0,
425
+ downsample_padding=1,
426
+ add_downsample=True,
427
+ ):
428
+ super().__init__()
429
+ resnets = []
430
+ attentions = []
431
+
432
+ self.attention_type = attention_type
433
+
434
+ for i in range(num_layers):
435
+ in_channels = in_channels if i == 0 else out_channels
436
+ resnets.append(
437
+ ResnetBlock2D(
438
+ in_channels=in_channels,
439
+ out_channels=out_channels,
440
+ temb_channels=temb_channels,
441
+ eps=resnet_eps,
442
+ groups=resnet_groups,
443
+ dropout=dropout,
444
+ time_embedding_norm=resnet_time_scale_shift,
445
+ non_linearity=resnet_act_fn,
446
+ output_scale_factor=output_scale_factor,
447
+ pre_norm=resnet_pre_norm,
448
+ )
449
+ )
450
+ attentions.append(
451
+ AttentionBlock(
452
+ out_channels,
453
+ num_head_channels=attn_num_head_channels,
454
+ rescale_output_factor=output_scale_factor,
455
+ eps=resnet_eps,
456
+ norm_num_groups=resnet_groups,
457
+ )
458
+ )
459
+
460
+ self.attentions = nn.ModuleList(attentions)
461
+ self.resnets = nn.ModuleList(resnets)
462
+
463
+ if add_downsample:
464
+ self.downsamplers = nn.ModuleList(
465
+ [
466
+ Downsample2D(
467
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
468
+ )
469
+ ]
470
+ )
471
+ else:
472
+ self.downsamplers = None
473
+
474
+ def forward(self, hidden_states, temb=None):
475
+ output_states = ()
476
+
477
+ for resnet, attn in zip(self.resnets, self.attentions):
478
+ hidden_states = resnet(hidden_states, temb)
479
+ hidden_states = attn(hidden_states)
480
+ output_states += (hidden_states,)
481
+
482
+ if self.downsamplers is not None:
483
+ for downsampler in self.downsamplers:
484
+ hidden_states = downsampler(hidden_states)
485
+
486
+ output_states += (hidden_states,)
487
+
488
+ return hidden_states, output_states
489
+
490
+
491
+ class CrossAttnDownBlock2D(nn.Module):
492
+ def __init__(
493
+ self,
494
+ in_channels: int,
495
+ out_channels: int,
496
+ temb_channels: int,
497
+ dropout: float = 0.0,
498
+ num_layers: int = 1,
499
+ resnet_eps: float = 1e-6,
500
+ resnet_time_scale_shift: str = "default",
501
+ resnet_act_fn: str = "swish",
502
+ resnet_groups: int = 32,
503
+ resnet_pre_norm: bool = True,
504
+ attn_num_head_channels=1,
505
+ cross_attention_dim=1280,
506
+ attention_type="default",
507
+ output_scale_factor=1.0,
508
+ downsample_padding=1,
509
+ add_downsample=True,
510
+ ):
511
+ super().__init__()
512
+ resnets = []
513
+ attentions = []
514
+
515
+ self.attention_type = attention_type
516
+ self.attn_num_head_channels = attn_num_head_channels
517
+
518
+ for i in range(num_layers):
519
+ in_channels = in_channels if i == 0 else out_channels
520
+ resnets.append(
521
+ ResnetBlock2D(
522
+ in_channels=in_channels,
523
+ out_channels=out_channels,
524
+ temb_channels=temb_channels,
525
+ eps=resnet_eps,
526
+ groups=resnet_groups,
527
+ dropout=dropout,
528
+ time_embedding_norm=resnet_time_scale_shift,
529
+ non_linearity=resnet_act_fn,
530
+ output_scale_factor=output_scale_factor,
531
+ pre_norm=resnet_pre_norm,
532
+ )
533
+ )
534
+ attentions.append(
535
+ Transformer2DModel(
536
+ attn_num_head_channels,
537
+ out_channels // attn_num_head_channels,
538
+ in_channels=out_channels,
539
+ num_layers=1,
540
+ cross_attention_dim=cross_attention_dim,
541
+ norm_num_groups=resnet_groups,
542
+ )
543
+ )
544
+ self.attentions = nn.ModuleList(attentions)
545
+ self.resnets = nn.ModuleList(resnets)
546
+
547
+ if add_downsample:
548
+ self.downsamplers = nn.ModuleList(
549
+ [
550
+ Downsample2D(
551
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
552
+ )
553
+ ]
554
+ )
555
+ else:
556
+ self.downsamplers = None
557
+
558
+ self.gradient_checkpointing = False
559
+
560
+ def set_attention_slice(self, slice_size):
561
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
562
+ raise ValueError(
563
+ f"Make sure slice_size {slice_size} is a divisor of "
564
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
565
+ )
566
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
567
+ raise ValueError(
568
+ f"Chunk_size {slice_size} has to be smaller or equal to "
569
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
570
+ )
571
+
572
+ for attn in self.attentions:
573
+ attn._set_attention_slice(slice_size)
574
+
575
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
576
+ for attn in self.attentions:
577
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
578
+
579
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
580
+ output_states = ()
581
+ cross_attn_prob_list = []
582
+ for layer_idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
583
+ if self.training and self.gradient_checkpointing:
584
+
585
+ def create_custom_forward(module, return_dict=None):
586
+ def custom_forward(*inputs):
587
+ if return_dict is not None:
588
+ return module(*inputs, return_dict=return_dict)
589
+ else:
590
+ return module(*inputs)
591
+
592
+ return custom_forward
593
+
594
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
595
+ hidden_states = torch.utils.checkpoint.checkpoint(
596
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
597
+ )[0]
598
+ else:
599
+ hidden_states = resnet(hidden_states, temb)
600
+ tmp_hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
601
+ hidden_states = tmp_hidden_states.sample
602
+
603
+ output_states += (hidden_states,)
604
+ cross_attn_prob_list.append(cross_attn_prob)
605
+ if self.downsamplers is not None:
606
+ for downsampler in self.downsamplers:
607
+ hidden_states = downsampler(hidden_states)
608
+
609
+ output_states += (hidden_states,)
610
+
611
+ return hidden_states, output_states, cross_attn_prob_list
612
+
613
+
614
+ class DownBlock2D(nn.Module):
615
+ def __init__(
616
+ self,
617
+ in_channels: int,
618
+ out_channels: int,
619
+ temb_channels: int,
620
+ dropout: float = 0.0,
621
+ num_layers: int = 1,
622
+ resnet_eps: float = 1e-6,
623
+ resnet_time_scale_shift: str = "default",
624
+ resnet_act_fn: str = "swish",
625
+ resnet_groups: int = 32,
626
+ resnet_pre_norm: bool = True,
627
+ output_scale_factor=1.0,
628
+ add_downsample=True,
629
+ downsample_padding=1,
630
+ ):
631
+ super().__init__()
632
+ resnets = []
633
+
634
+ for i in range(num_layers):
635
+ in_channels = in_channels if i == 0 else out_channels
636
+ resnets.append(
637
+ ResnetBlock2D(
638
+ in_channels=in_channels,
639
+ out_channels=out_channels,
640
+ temb_channels=temb_channels,
641
+ eps=resnet_eps,
642
+ groups=resnet_groups,
643
+ dropout=dropout,
644
+ time_embedding_norm=resnet_time_scale_shift,
645
+ non_linearity=resnet_act_fn,
646
+ output_scale_factor=output_scale_factor,
647
+ pre_norm=resnet_pre_norm,
648
+ )
649
+ )
650
+
651
+ self.resnets = nn.ModuleList(resnets)
652
+
653
+ if add_downsample:
654
+ self.downsamplers = nn.ModuleList(
655
+ [
656
+ Downsample2D(
657
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
658
+ )
659
+ ]
660
+ )
661
+ else:
662
+ self.downsamplers = None
663
+
664
+ self.gradient_checkpointing = False
665
+
666
+ def forward(self, hidden_states, temb=None):
667
+ output_states = ()
668
+
669
+ for resnet in self.resnets:
670
+ if self.training and self.gradient_checkpointing:
671
+
672
+ def create_custom_forward(module):
673
+ def custom_forward(*inputs):
674
+ return module(*inputs)
675
+
676
+ return custom_forward
677
+
678
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
679
+ else:
680
+ hidden_states = resnet(hidden_states, temb)
681
+
682
+ output_states += (hidden_states,)
683
+
684
+ if self.downsamplers is not None:
685
+ for downsampler in self.downsamplers:
686
+ hidden_states = downsampler(hidden_states)
687
+
688
+ output_states += (hidden_states,)
689
+
690
+ return hidden_states, output_states
691
+
692
+
693
+ class DownEncoderBlock2D(nn.Module):
694
+ def __init__(
695
+ self,
696
+ in_channels: int,
697
+ out_channels: int,
698
+ dropout: float = 0.0,
699
+ num_layers: int = 1,
700
+ resnet_eps: float = 1e-6,
701
+ resnet_time_scale_shift: str = "default",
702
+ resnet_act_fn: str = "swish",
703
+ resnet_groups: int = 32,
704
+ resnet_pre_norm: bool = True,
705
+ output_scale_factor=1.0,
706
+ add_downsample=True,
707
+ downsample_padding=1,
708
+ ):
709
+ super().__init__()
710
+ resnets = []
711
+
712
+ for i in range(num_layers):
713
+ in_channels = in_channels if i == 0 else out_channels
714
+ resnets.append(
715
+ ResnetBlock2D(
716
+ in_channels=in_channels,
717
+ out_channels=out_channels,
718
+ temb_channels=None,
719
+ eps=resnet_eps,
720
+ groups=resnet_groups,
721
+ dropout=dropout,
722
+ time_embedding_norm=resnet_time_scale_shift,
723
+ non_linearity=resnet_act_fn,
724
+ output_scale_factor=output_scale_factor,
725
+ pre_norm=resnet_pre_norm,
726
+ )
727
+ )
728
+
729
+ self.resnets = nn.ModuleList(resnets)
730
+
731
+ if add_downsample:
732
+ self.downsamplers = nn.ModuleList(
733
+ [
734
+ Downsample2D(
735
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
736
+ )
737
+ ]
738
+ )
739
+ else:
740
+ self.downsamplers = None
741
+
742
+ def forward(self, hidden_states):
743
+ for resnet in self.resnets:
744
+ hidden_states = resnet(hidden_states, temb=None)
745
+
746
+ if self.downsamplers is not None:
747
+ for downsampler in self.downsamplers:
748
+ hidden_states = downsampler(hidden_states)
749
+
750
+ return hidden_states
751
+
752
+
753
+ class AttnDownEncoderBlock2D(nn.Module):
754
+ def __init__(
755
+ self,
756
+ in_channels: int,
757
+ out_channels: int,
758
+ dropout: float = 0.0,
759
+ num_layers: int = 1,
760
+ resnet_eps: float = 1e-6,
761
+ resnet_time_scale_shift: str = "default",
762
+ resnet_act_fn: str = "swish",
763
+ resnet_groups: int = 32,
764
+ resnet_pre_norm: bool = True,
765
+ attn_num_head_channels=1,
766
+ output_scale_factor=1.0,
767
+ add_downsample=True,
768
+ downsample_padding=1,
769
+ ):
770
+ super().__init__()
771
+ resnets = []
772
+ attentions = []
773
+
774
+ for i in range(num_layers):
775
+ in_channels = in_channels if i == 0 else out_channels
776
+ resnets.append(
777
+ ResnetBlock2D(
778
+ in_channels=in_channels,
779
+ out_channels=out_channels,
780
+ temb_channels=None,
781
+ eps=resnet_eps,
782
+ groups=resnet_groups,
783
+ dropout=dropout,
784
+ time_embedding_norm=resnet_time_scale_shift,
785
+ non_linearity=resnet_act_fn,
786
+ output_scale_factor=output_scale_factor,
787
+ pre_norm=resnet_pre_norm,
788
+ )
789
+ )
790
+ attentions.append(
791
+ AttentionBlock(
792
+ out_channels,
793
+ num_head_channels=attn_num_head_channels,
794
+ rescale_output_factor=output_scale_factor,
795
+ eps=resnet_eps,
796
+ norm_num_groups=resnet_groups,
797
+ )
798
+ )
799
+
800
+ self.attentions = nn.ModuleList(attentions)
801
+ self.resnets = nn.ModuleList(resnets)
802
+
803
+ if add_downsample:
804
+ self.downsamplers = nn.ModuleList(
805
+ [
806
+ Downsample2D(
807
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
808
+ )
809
+ ]
810
+ )
811
+ else:
812
+ self.downsamplers = None
813
+
814
+ def forward(self, hidden_states):
815
+ for resnet, attn in zip(self.resnets, self.attentions):
816
+ hidden_states = resnet(hidden_states, temb=None)
817
+ hidden_states = attn(hidden_states)
818
+
819
+ if self.downsamplers is not None:
820
+ for downsampler in self.downsamplers:
821
+ hidden_states = downsampler(hidden_states)
822
+
823
+ return hidden_states
824
+
825
+
826
+ class AttnSkipDownBlock2D(nn.Module):
827
+ def __init__(
828
+ self,
829
+ in_channels: int,
830
+ out_channels: int,
831
+ temb_channels: int,
832
+ dropout: float = 0.0,
833
+ num_layers: int = 1,
834
+ resnet_eps: float = 1e-6,
835
+ resnet_time_scale_shift: str = "default",
836
+ resnet_act_fn: str = "swish",
837
+ resnet_pre_norm: bool = True,
838
+ attn_num_head_channels=1,
839
+ attention_type="default",
840
+ output_scale_factor=np.sqrt(2.0),
841
+ downsample_padding=1,
842
+ add_downsample=True,
843
+ ):
844
+ super().__init__()
845
+ self.attentions = nn.ModuleList([])
846
+ self.resnets = nn.ModuleList([])
847
+
848
+ self.attention_type = attention_type
849
+
850
+ for i in range(num_layers):
851
+ in_channels = in_channels if i == 0 else out_channels
852
+ self.resnets.append(
853
+ ResnetBlock2D(
854
+ in_channels=in_channels,
855
+ out_channels=out_channels,
856
+ temb_channels=temb_channels,
857
+ eps=resnet_eps,
858
+ groups=min(in_channels // 4, 32),
859
+ groups_out=min(out_channels // 4, 32),
860
+ dropout=dropout,
861
+ time_embedding_norm=resnet_time_scale_shift,
862
+ non_linearity=resnet_act_fn,
863
+ output_scale_factor=output_scale_factor,
864
+ pre_norm=resnet_pre_norm,
865
+ )
866
+ )
867
+ self.attentions.append(
868
+ AttentionBlock(
869
+ out_channels,
870
+ num_head_channels=attn_num_head_channels,
871
+ rescale_output_factor=output_scale_factor,
872
+ eps=resnet_eps,
873
+ )
874
+ )
875
+
876
+ if add_downsample:
877
+ self.resnet_down = ResnetBlock2D(
878
+ in_channels=out_channels,
879
+ out_channels=out_channels,
880
+ temb_channels=temb_channels,
881
+ eps=resnet_eps,
882
+ groups=min(out_channels // 4, 32),
883
+ dropout=dropout,
884
+ time_embedding_norm=resnet_time_scale_shift,
885
+ non_linearity=resnet_act_fn,
886
+ output_scale_factor=output_scale_factor,
887
+ pre_norm=resnet_pre_norm,
888
+ use_in_shortcut=True,
889
+ down=True,
890
+ kernel="fir",
891
+ )
892
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
893
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
894
+ else:
895
+ self.resnet_down = None
896
+ self.downsamplers = None
897
+ self.skip_conv = None
898
+
899
+ def forward(self, hidden_states, temb=None, skip_sample=None):
900
+ output_states = ()
901
+
902
+ for resnet, attn in zip(self.resnets, self.attentions):
903
+ hidden_states = resnet(hidden_states, temb)
904
+ hidden_states = attn(hidden_states)
905
+ output_states += (hidden_states,)
906
+
907
+ if self.downsamplers is not None:
908
+ hidden_states = self.resnet_down(hidden_states, temb)
909
+ for downsampler in self.downsamplers:
910
+ skip_sample = downsampler(skip_sample)
911
+
912
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
913
+
914
+ output_states += (hidden_states,)
915
+
916
+ return hidden_states, output_states, skip_sample
917
+
918
+
919
+ class SkipDownBlock2D(nn.Module):
920
+ def __init__(
921
+ self,
922
+ in_channels: int,
923
+ out_channels: int,
924
+ temb_channels: int,
925
+ dropout: float = 0.0,
926
+ num_layers: int = 1,
927
+ resnet_eps: float = 1e-6,
928
+ resnet_time_scale_shift: str = "default",
929
+ resnet_act_fn: str = "swish",
930
+ resnet_pre_norm: bool = True,
931
+ output_scale_factor=np.sqrt(2.0),
932
+ add_downsample=True,
933
+ downsample_padding=1,
934
+ ):
935
+ super().__init__()
936
+ self.resnets = nn.ModuleList([])
937
+
938
+ for i in range(num_layers):
939
+ in_channels = in_channels if i == 0 else out_channels
940
+ self.resnets.append(
941
+ ResnetBlock2D(
942
+ in_channels=in_channels,
943
+ out_channels=out_channels,
944
+ temb_channels=temb_channels,
945
+ eps=resnet_eps,
946
+ groups=min(in_channels // 4, 32),
947
+ groups_out=min(out_channels // 4, 32),
948
+ dropout=dropout,
949
+ time_embedding_norm=resnet_time_scale_shift,
950
+ non_linearity=resnet_act_fn,
951
+ output_scale_factor=output_scale_factor,
952
+ pre_norm=resnet_pre_norm,
953
+ )
954
+ )
955
+
956
+ if add_downsample:
957
+ self.resnet_down = ResnetBlock2D(
958
+ in_channels=out_channels,
959
+ out_channels=out_channels,
960
+ temb_channels=temb_channels,
961
+ eps=resnet_eps,
962
+ groups=min(out_channels // 4, 32),
963
+ dropout=dropout,
964
+ time_embedding_norm=resnet_time_scale_shift,
965
+ non_linearity=resnet_act_fn,
966
+ output_scale_factor=output_scale_factor,
967
+ pre_norm=resnet_pre_norm,
968
+ use_in_shortcut=True,
969
+ down=True,
970
+ kernel="fir",
971
+ )
972
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
973
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
974
+ else:
975
+ self.resnet_down = None
976
+ self.downsamplers = None
977
+ self.skip_conv = None
978
+
979
+ def forward(self, hidden_states, temb=None, skip_sample=None):
980
+ output_states = ()
981
+
982
+ for resnet in self.resnets:
983
+ hidden_states = resnet(hidden_states, temb)
984
+ output_states += (hidden_states,)
985
+
986
+ if self.downsamplers is not None:
987
+ hidden_states = self.resnet_down(hidden_states, temb)
988
+ for downsampler in self.downsamplers:
989
+ skip_sample = downsampler(skip_sample)
990
+
991
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
992
+
993
+ output_states += (hidden_states,)
994
+
995
+ return hidden_states, output_states, skip_sample
996
+
997
+
998
+ class AttnUpBlock2D(nn.Module):
999
+ def __init__(
1000
+ self,
1001
+ in_channels: int,
1002
+ prev_output_channel: int,
1003
+ out_channels: int,
1004
+ temb_channels: int,
1005
+ dropout: float = 0.0,
1006
+ num_layers: int = 1,
1007
+ resnet_eps: float = 1e-6,
1008
+ resnet_time_scale_shift: str = "default",
1009
+ resnet_act_fn: str = "swish",
1010
+ resnet_groups: int = 32,
1011
+ resnet_pre_norm: bool = True,
1012
+ attention_type="default",
1013
+ attn_num_head_channels=1,
1014
+ output_scale_factor=1.0,
1015
+ add_upsample=True,
1016
+ ):
1017
+ super().__init__()
1018
+ resnets = []
1019
+ attentions = []
1020
+
1021
+ self.attention_type = attention_type
1022
+
1023
+ for i in range(num_layers):
1024
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1025
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1026
+
1027
+ resnets.append(
1028
+ ResnetBlock2D(
1029
+ in_channels=resnet_in_channels + res_skip_channels,
1030
+ out_channels=out_channels,
1031
+ temb_channels=temb_channels,
1032
+ eps=resnet_eps,
1033
+ groups=resnet_groups,
1034
+ dropout=dropout,
1035
+ time_embedding_norm=resnet_time_scale_shift,
1036
+ non_linearity=resnet_act_fn,
1037
+ output_scale_factor=output_scale_factor,
1038
+ pre_norm=resnet_pre_norm,
1039
+ )
1040
+ )
1041
+ attentions.append(
1042
+ AttentionBlock(
1043
+ out_channels,
1044
+ num_head_channels=attn_num_head_channels,
1045
+ rescale_output_factor=output_scale_factor,
1046
+ eps=resnet_eps,
1047
+ norm_num_groups=resnet_groups,
1048
+ )
1049
+ )
1050
+
1051
+ self.attentions = nn.ModuleList(attentions)
1052
+ self.resnets = nn.ModuleList(resnets)
1053
+
1054
+ if add_upsample:
1055
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1056
+ else:
1057
+ self.upsamplers = None
1058
+
1059
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1060
+ for resnet, attn in zip(self.resnets, self.attentions):
1061
+ # pop res hidden states
1062
+ res_hidden_states = res_hidden_states_tuple[-1]
1063
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1064
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1065
+
1066
+ hidden_states = resnet(hidden_states, temb)
1067
+ hidden_states = attn(hidden_states)
1068
+
1069
+ if self.upsamplers is not None:
1070
+ for upsampler in self.upsamplers:
1071
+ hidden_states = upsampler(hidden_states)
1072
+
1073
+ return hidden_states
1074
+
1075
+
1076
+ class CrossAttnUpBlock2D(nn.Module):
1077
+ def __init__(
1078
+ self,
1079
+ in_channels: int,
1080
+ out_channels: int,
1081
+ prev_output_channel: int,
1082
+ temb_channels: int,
1083
+ dropout: float = 0.0,
1084
+ num_layers: int = 1,
1085
+ resnet_eps: float = 1e-6,
1086
+ resnet_time_scale_shift: str = "default",
1087
+ resnet_act_fn: str = "swish",
1088
+ resnet_groups: int = 32,
1089
+ resnet_pre_norm: bool = True,
1090
+ attn_num_head_channels=1,
1091
+ cross_attention_dim=1280,
1092
+ attention_type="default",
1093
+ output_scale_factor=1.0,
1094
+ add_upsample=True,
1095
+ ):
1096
+ super().__init__()
1097
+ resnets = []
1098
+ attentions = []
1099
+
1100
+ self.attention_type = attention_type
1101
+ self.attn_num_head_channels = attn_num_head_channels
1102
+
1103
+ for i in range(num_layers):
1104
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1105
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1106
+
1107
+ resnets.append(
1108
+ ResnetBlock2D(
1109
+ in_channels=resnet_in_channels + res_skip_channels,
1110
+ out_channels=out_channels,
1111
+ temb_channels=temb_channels,
1112
+ eps=resnet_eps,
1113
+ groups=resnet_groups,
1114
+ dropout=dropout,
1115
+ time_embedding_norm=resnet_time_scale_shift,
1116
+ non_linearity=resnet_act_fn,
1117
+ output_scale_factor=output_scale_factor,
1118
+ pre_norm=resnet_pre_norm,
1119
+ )
1120
+ )
1121
+ attentions.append(
1122
+ Transformer2DModel(
1123
+ attn_num_head_channels,
1124
+ out_channels // attn_num_head_channels,
1125
+ in_channels=out_channels,
1126
+ num_layers=1,
1127
+ cross_attention_dim=cross_attention_dim,
1128
+ norm_num_groups=resnet_groups,
1129
+ )
1130
+ )
1131
+ self.attentions = nn.ModuleList(attentions)
1132
+ self.resnets = nn.ModuleList(resnets)
1133
+
1134
+ if add_upsample:
1135
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1136
+ else:
1137
+ self.upsamplers = None
1138
+
1139
+ self.gradient_checkpointing = False
1140
+
1141
+ def set_attention_slice(self, slice_size):
1142
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1143
+ raise ValueError(
1144
+ f"Make sure slice_size {slice_size} is a divisor of "
1145
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1146
+ )
1147
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1148
+ raise ValueError(
1149
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1150
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1151
+ )
1152
+
1153
+ for attn in self.attentions:
1154
+ attn._set_attention_slice(slice_size)
1155
+
1156
+ self.gradient_checkpointing = False
1157
+
1158
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1159
+ for attn in self.attentions:
1160
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
1161
+
1162
+ def forward(
1163
+ self,
1164
+ hidden_states,
1165
+ res_hidden_states_tuple,
1166
+ temb=None,
1167
+ encoder_hidden_states=None,
1168
+ upsample_size=None,
1169
+ ):
1170
+ cross_attn_prob_list = list()
1171
+ for layer_idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
1172
+ # pop res hidden states
1173
+ res_hidden_states = res_hidden_states_tuple[-1]
1174
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1175
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1176
+
1177
+ if self.training and self.gradient_checkpointing:
1178
+
1179
+ def create_custom_forward(module, return_dict=None):
1180
+ def custom_forward(*inputs):
1181
+ if return_dict is not None:
1182
+ return module(*inputs, return_dict=return_dict)
1183
+ else:
1184
+ return module(*inputs)
1185
+
1186
+ return custom_forward
1187
+
1188
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1189
+ hidden_states = torch.utils.checkpoint.checkpoint(
1190
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1191
+ )[0]
1192
+ else:
1193
+ hidden_states = resnet(hidden_states, temb)
1194
+ tmp_hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
1195
+ hidden_states = tmp_hidden_states.sample
1196
+ cross_attn_prob_list.append(cross_attn_prob)
1197
+ if self.upsamplers is not None:
1198
+ for upsampler in self.upsamplers:
1199
+ hidden_states = upsampler(hidden_states, upsample_size)
1200
+
1201
+ return hidden_states, cross_attn_prob_list
1202
+
1203
+
1204
+ class UpBlock2D(nn.Module):
1205
+ def __init__(
1206
+ self,
1207
+ in_channels: int,
1208
+ prev_output_channel: int,
1209
+ out_channels: int,
1210
+ temb_channels: int,
1211
+ dropout: float = 0.0,
1212
+ num_layers: int = 1,
1213
+ resnet_eps: float = 1e-6,
1214
+ resnet_time_scale_shift: str = "default",
1215
+ resnet_act_fn: str = "swish",
1216
+ resnet_groups: int = 32,
1217
+ resnet_pre_norm: bool = True,
1218
+ output_scale_factor=1.0,
1219
+ add_upsample=True,
1220
+ ):
1221
+ super().__init__()
1222
+ resnets = []
1223
+
1224
+ for i in range(num_layers):
1225
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1226
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1227
+
1228
+ resnets.append(
1229
+ ResnetBlock2D(
1230
+ in_channels=resnet_in_channels + res_skip_channels,
1231
+ out_channels=out_channels,
1232
+ temb_channels=temb_channels,
1233
+ eps=resnet_eps,
1234
+ groups=resnet_groups,
1235
+ dropout=dropout,
1236
+ time_embedding_norm=resnet_time_scale_shift,
1237
+ non_linearity=resnet_act_fn,
1238
+ output_scale_factor=output_scale_factor,
1239
+ pre_norm=resnet_pre_norm,
1240
+ )
1241
+ )
1242
+
1243
+ self.resnets = nn.ModuleList(resnets)
1244
+
1245
+ if add_upsample:
1246
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1247
+ else:
1248
+ self.upsamplers = None
1249
+
1250
+ self.gradient_checkpointing = False
1251
+
1252
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1253
+ for resnet in self.resnets:
1254
+ # pop res hidden states
1255
+ res_hidden_states = res_hidden_states_tuple[-1]
1256
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1257
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1258
+
1259
+ if self.training and self.gradient_checkpointing:
1260
+
1261
+ def create_custom_forward(module):
1262
+ def custom_forward(*inputs):
1263
+ return module(*inputs)
1264
+
1265
+ return custom_forward
1266
+
1267
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1268
+ else:
1269
+ hidden_states = resnet(hidden_states, temb)
1270
+
1271
+ if self.upsamplers is not None:
1272
+ for upsampler in self.upsamplers:
1273
+ hidden_states = upsampler(hidden_states, upsample_size)
1274
+
1275
+ return hidden_states
1276
+
1277
+
1278
+ class UpDecoderBlock2D(nn.Module):
1279
+ def __init__(
1280
+ self,
1281
+ in_channels: int,
1282
+ out_channels: int,
1283
+ dropout: float = 0.0,
1284
+ num_layers: int = 1,
1285
+ resnet_eps: float = 1e-6,
1286
+ resnet_time_scale_shift: str = "default",
1287
+ resnet_act_fn: str = "swish",
1288
+ resnet_groups: int = 32,
1289
+ resnet_pre_norm: bool = True,
1290
+ output_scale_factor=1.0,
1291
+ add_upsample=True,
1292
+ ):
1293
+ super().__init__()
1294
+ resnets = []
1295
+
1296
+ for i in range(num_layers):
1297
+ input_channels = in_channels if i == 0 else out_channels
1298
+
1299
+ resnets.append(
1300
+ ResnetBlock2D(
1301
+ in_channels=input_channels,
1302
+ out_channels=out_channels,
1303
+ temb_channels=None,
1304
+ eps=resnet_eps,
1305
+ groups=resnet_groups,
1306
+ dropout=dropout,
1307
+ time_embedding_norm=resnet_time_scale_shift,
1308
+ non_linearity=resnet_act_fn,
1309
+ output_scale_factor=output_scale_factor,
1310
+ pre_norm=resnet_pre_norm,
1311
+ )
1312
+ )
1313
+
1314
+ self.resnets = nn.ModuleList(resnets)
1315
+
1316
+ if add_upsample:
1317
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1318
+ else:
1319
+ self.upsamplers = None
1320
+
1321
+ def forward(self, hidden_states):
1322
+ for resnet in self.resnets:
1323
+ hidden_states = resnet(hidden_states, temb=None)
1324
+
1325
+ if self.upsamplers is not None:
1326
+ for upsampler in self.upsamplers:
1327
+ hidden_states = upsampler(hidden_states)
1328
+
1329
+ return hidden_states
1330
+
1331
+
1332
+ class AttnUpDecoderBlock2D(nn.Module):
1333
+ def __init__(
1334
+ self,
1335
+ in_channels: int,
1336
+ out_channels: int,
1337
+ dropout: float = 0.0,
1338
+ num_layers: int = 1,
1339
+ resnet_eps: float = 1e-6,
1340
+ resnet_time_scale_shift: str = "default",
1341
+ resnet_act_fn: str = "swish",
1342
+ resnet_groups: int = 32,
1343
+ resnet_pre_norm: bool = True,
1344
+ attn_num_head_channels=1,
1345
+ output_scale_factor=1.0,
1346
+ add_upsample=True,
1347
+ ):
1348
+ super().__init__()
1349
+ resnets = []
1350
+ attentions = []
1351
+
1352
+ for i in range(num_layers):
1353
+ input_channels = in_channels if i == 0 else out_channels
1354
+
1355
+ resnets.append(
1356
+ ResnetBlock2D(
1357
+ in_channels=input_channels,
1358
+ out_channels=out_channels,
1359
+ temb_channels=None,
1360
+ eps=resnet_eps,
1361
+ groups=resnet_groups,
1362
+ dropout=dropout,
1363
+ time_embedding_norm=resnet_time_scale_shift,
1364
+ non_linearity=resnet_act_fn,
1365
+ output_scale_factor=output_scale_factor,
1366
+ pre_norm=resnet_pre_norm,
1367
+ )
1368
+ )
1369
+ attentions.append(
1370
+ AttentionBlock(
1371
+ out_channels,
1372
+ num_head_channels=attn_num_head_channels,
1373
+ rescale_output_factor=output_scale_factor,
1374
+ eps=resnet_eps,
1375
+ norm_num_groups=resnet_groups,
1376
+ )
1377
+ )
1378
+
1379
+ self.attentions = nn.ModuleList(attentions)
1380
+ self.resnets = nn.ModuleList(resnets)
1381
+
1382
+ if add_upsample:
1383
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1384
+ else:
1385
+ self.upsamplers = None
1386
+
1387
+ def forward(self, hidden_states):
1388
+ for resnet, attn in zip(self.resnets, self.attentions):
1389
+ hidden_states = resnet(hidden_states, temb=None)
1390
+ hidden_states = attn(hidden_states)
1391
+
1392
+ if self.upsamplers is not None:
1393
+ for upsampler in self.upsamplers:
1394
+ hidden_states = upsampler(hidden_states)
1395
+
1396
+ return hidden_states
1397
+
1398
+
1399
+ class AttnSkipUpBlock2D(nn.Module):
1400
+ def __init__(
1401
+ self,
1402
+ in_channels: int,
1403
+ prev_output_channel: int,
1404
+ out_channels: int,
1405
+ temb_channels: int,
1406
+ dropout: float = 0.0,
1407
+ num_layers: int = 1,
1408
+ resnet_eps: float = 1e-6,
1409
+ resnet_time_scale_shift: str = "default",
1410
+ resnet_act_fn: str = "swish",
1411
+ resnet_pre_norm: bool = True,
1412
+ attn_num_head_channels=1,
1413
+ attention_type="default",
1414
+ output_scale_factor=np.sqrt(2.0),
1415
+ upsample_padding=1,
1416
+ add_upsample=True,
1417
+ ):
1418
+ super().__init__()
1419
+ self.attentions = nn.ModuleList([])
1420
+ self.resnets = nn.ModuleList([])
1421
+
1422
+ self.attention_type = attention_type
1423
+
1424
+ for i in range(num_layers):
1425
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1426
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1427
+
1428
+ self.resnets.append(
1429
+ ResnetBlock2D(
1430
+ in_channels=resnet_in_channels + res_skip_channels,
1431
+ out_channels=out_channels,
1432
+ temb_channels=temb_channels,
1433
+ eps=resnet_eps,
1434
+ groups=min(resnet_in_channels + res_skip_channels // 4, 32),
1435
+ groups_out=min(out_channels // 4, 32),
1436
+ dropout=dropout,
1437
+ time_embedding_norm=resnet_time_scale_shift,
1438
+ non_linearity=resnet_act_fn,
1439
+ output_scale_factor=output_scale_factor,
1440
+ pre_norm=resnet_pre_norm,
1441
+ )
1442
+ )
1443
+
1444
+ self.attentions.append(
1445
+ AttentionBlock(
1446
+ out_channels,
1447
+ num_head_channels=attn_num_head_channels,
1448
+ rescale_output_factor=output_scale_factor,
1449
+ eps=resnet_eps,
1450
+ )
1451
+ )
1452
+
1453
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1454
+ if add_upsample:
1455
+ self.resnet_up = ResnetBlock2D(
1456
+ in_channels=out_channels,
1457
+ out_channels=out_channels,
1458
+ temb_channels=temb_channels,
1459
+ eps=resnet_eps,
1460
+ groups=min(out_channels // 4, 32),
1461
+ groups_out=min(out_channels // 4, 32),
1462
+ dropout=dropout,
1463
+ time_embedding_norm=resnet_time_scale_shift,
1464
+ non_linearity=resnet_act_fn,
1465
+ output_scale_factor=output_scale_factor,
1466
+ pre_norm=resnet_pre_norm,
1467
+ use_in_shortcut=True,
1468
+ up=True,
1469
+ kernel="fir",
1470
+ )
1471
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1472
+ self.skip_norm = torch.nn.GroupNorm(
1473
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1474
+ )
1475
+ self.act = nn.SiLU()
1476
+ else:
1477
+ self.resnet_up = None
1478
+ self.skip_conv = None
1479
+ self.skip_norm = None
1480
+ self.act = None
1481
+
1482
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1483
+ for resnet in self.resnets:
1484
+ # pop res hidden states
1485
+ res_hidden_states = res_hidden_states_tuple[-1]
1486
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1487
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1488
+
1489
+ hidden_states = resnet(hidden_states, temb)
1490
+
1491
+ hidden_states = self.attentions[0](hidden_states)
1492
+
1493
+ if skip_sample is not None:
1494
+ skip_sample = self.upsampler(skip_sample)
1495
+ else:
1496
+ skip_sample = 0
1497
+
1498
+ if self.resnet_up is not None:
1499
+ skip_sample_states = self.skip_norm(hidden_states)
1500
+ skip_sample_states = self.act(skip_sample_states)
1501
+ skip_sample_states = self.skip_conv(skip_sample_states)
1502
+
1503
+ skip_sample = skip_sample + skip_sample_states
1504
+
1505
+ hidden_states = self.resnet_up(hidden_states, temb)
1506
+
1507
+ return hidden_states, skip_sample
1508
+
1509
+
1510
+ class SkipUpBlock2D(nn.Module):
1511
+ def __init__(
1512
+ self,
1513
+ in_channels: int,
1514
+ prev_output_channel: int,
1515
+ out_channels: int,
1516
+ temb_channels: int,
1517
+ dropout: float = 0.0,
1518
+ num_layers: int = 1,
1519
+ resnet_eps: float = 1e-6,
1520
+ resnet_time_scale_shift: str = "default",
1521
+ resnet_act_fn: str = "swish",
1522
+ resnet_pre_norm: bool = True,
1523
+ output_scale_factor=np.sqrt(2.0),
1524
+ add_upsample=True,
1525
+ upsample_padding=1,
1526
+ ):
1527
+ super().__init__()
1528
+ self.resnets = nn.ModuleList([])
1529
+
1530
+ for i in range(num_layers):
1531
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1532
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1533
+
1534
+ self.resnets.append(
1535
+ ResnetBlock2D(
1536
+ in_channels=resnet_in_channels + res_skip_channels,
1537
+ out_channels=out_channels,
1538
+ temb_channels=temb_channels,
1539
+ eps=resnet_eps,
1540
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1541
+ groups_out=min(out_channels // 4, 32),
1542
+ dropout=dropout,
1543
+ time_embedding_norm=resnet_time_scale_shift,
1544
+ non_linearity=resnet_act_fn,
1545
+ output_scale_factor=output_scale_factor,
1546
+ pre_norm=resnet_pre_norm,
1547
+ )
1548
+ )
1549
+
1550
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1551
+ if add_upsample:
1552
+ self.resnet_up = ResnetBlock2D(
1553
+ in_channels=out_channels,
1554
+ out_channels=out_channels,
1555
+ temb_channels=temb_channels,
1556
+ eps=resnet_eps,
1557
+ groups=min(out_channels // 4, 32),
1558
+ groups_out=min(out_channels // 4, 32),
1559
+ dropout=dropout,
1560
+ time_embedding_norm=resnet_time_scale_shift,
1561
+ non_linearity=resnet_act_fn,
1562
+ output_scale_factor=output_scale_factor,
1563
+ pre_norm=resnet_pre_norm,
1564
+ use_in_shortcut=True,
1565
+ up=True,
1566
+ kernel="fir",
1567
+ )
1568
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1569
+ self.skip_norm = torch.nn.GroupNorm(
1570
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1571
+ )
1572
+ self.act = nn.SiLU()
1573
+ else:
1574
+ self.resnet_up = None
1575
+ self.skip_conv = None
1576
+ self.skip_norm = None
1577
+ self.act = None
1578
+
1579
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1580
+ for resnet in self.resnets:
1581
+ # pop res hidden states
1582
+ res_hidden_states = res_hidden_states_tuple[-1]
1583
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1584
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1585
+
1586
+ hidden_states = resnet(hidden_states, temb)
1587
+
1588
+ if skip_sample is not None:
1589
+ skip_sample = self.upsampler(skip_sample)
1590
+ else:
1591
+ skip_sample = 0
1592
+
1593
+ if self.resnet_up is not None:
1594
+ skip_sample_states = self.skip_norm(hidden_states)
1595
+ skip_sample_states = self.act(skip_sample_states)
1596
+ skip_sample_states = self.skip_conv(skip_sample_states)
1597
+
1598
+ skip_sample = skip_sample + skip_sample_states
1599
+
1600
+ hidden_states = self.resnet_up(hidden_states, temb)
1601
+
1602
+ return hidden_states, skip_sample
my_model/unet_2d_condition.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import pdb
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.modeling_utils import ModelMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ from .unet_2d_blocks import (
27
+ CrossAttnDownBlock2D,
28
+ CrossAttnUpBlock2D,
29
+ DownBlock2D,
30
+ UNetMidBlock2DCrossAttn,
31
+ UpBlock2D,
32
+ get_down_block,
33
+ get_up_block,
34
+ )
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ @dataclass
41
+ class UNet2DConditionOutput(BaseOutput):
42
+ """
43
+ Args:
44
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
45
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
46
+ """
47
+
48
+ sample: torch.FloatTensor
49
+
50
+
51
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
52
+ r"""
53
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
54
+ and returns sample shaped output.
55
+
56
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
57
+ implements for all the models (such as downloading or saving, etc.)
58
+
59
+ Parameters:
60
+ sample_size (`int`, *optional*): The size of the input sample.
61
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
62
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
63
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
64
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
65
+ Whether to flip the sin to cos in the time embedding.
66
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
67
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
68
+ The tuple of downsample blocks to use.
69
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
70
+ The tuple of upsample blocks to use.
71
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
72
+ The tuple of output channels for each block.
73
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
74
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
75
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
76
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
77
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
78
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
80
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
81
+ """
82
+
83
+ _supports_gradient_checkpointing = True
84
+
85
+ @register_to_config
86
+ def __init__(
87
+ self,
88
+ sample_size: Optional[int] = None,
89
+ in_channels: int = 4,
90
+ out_channels: int = 4,
91
+ center_input_sample: bool = False,
92
+ flip_sin_to_cos: bool = True,
93
+ freq_shift: int = 0,
94
+ down_block_types: Tuple[str] = (
95
+ "CrossAttnDownBlock2D",
96
+ "CrossAttnDownBlock2D",
97
+ "CrossAttnDownBlock2D",
98
+ "DownBlock2D",
99
+ ),
100
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
101
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
102
+ layers_per_block: int = 2,
103
+ downsample_padding: int = 1,
104
+ mid_block_scale_factor: float = 1,
105
+ act_fn: str = "silu",
106
+ norm_num_groups: int = 32,
107
+ norm_eps: float = 1e-5,
108
+ cross_attention_dim: int = 1280,
109
+ attention_head_dim: int = 8,
110
+ ):
111
+ super().__init__()
112
+
113
+ self.sample_size = sample_size
114
+ time_embed_dim = block_out_channels[0] * 4
115
+
116
+ # input
117
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
118
+
119
+ # time
120
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
121
+ timestep_input_dim = block_out_channels[0]
122
+
123
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
124
+
125
+ self.down_blocks = nn.ModuleList([])
126
+ self.mid_block = None
127
+ self.up_blocks = nn.ModuleList([])
128
+
129
+ # down
130
+ output_channel = block_out_channels[0]
131
+ for i, down_block_type in enumerate(down_block_types):
132
+ input_channel = output_channel
133
+ output_channel = block_out_channels[i]
134
+ is_final_block = i == len(block_out_channels) - 1
135
+
136
+ down_block = get_down_block(
137
+ down_block_type,
138
+ num_layers=layers_per_block,
139
+ in_channels=input_channel,
140
+ out_channels=output_channel,
141
+ temb_channels=time_embed_dim,
142
+ add_downsample=not is_final_block,
143
+ resnet_eps=norm_eps,
144
+ resnet_act_fn=act_fn,
145
+ resnet_groups=norm_num_groups,
146
+ cross_attention_dim=cross_attention_dim,
147
+ attn_num_head_channels=attention_head_dim,
148
+ downsample_padding=downsample_padding,
149
+ )
150
+ self.down_blocks.append(down_block)
151
+
152
+ # mid
153
+ self.mid_block = UNetMidBlock2DCrossAttn(
154
+ in_channels=block_out_channels[-1],
155
+ temb_channels=time_embed_dim,
156
+ resnet_eps=norm_eps,
157
+ resnet_act_fn=act_fn,
158
+ output_scale_factor=mid_block_scale_factor,
159
+ resnet_time_scale_shift="default",
160
+ cross_attention_dim=cross_attention_dim,
161
+ attn_num_head_channels=attention_head_dim,
162
+ resnet_groups=norm_num_groups,
163
+ )
164
+
165
+ # count how many layers upsample the images
166
+ self.num_upsamplers = 0
167
+
168
+ # up
169
+ reversed_block_out_channels = list(reversed(block_out_channels))
170
+ output_channel = reversed_block_out_channels[0]
171
+ for i, up_block_type in enumerate(up_block_types):
172
+ is_final_block = i == len(block_out_channels) - 1
173
+
174
+ prev_output_channel = output_channel
175
+ output_channel = reversed_block_out_channels[i]
176
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
177
+
178
+ # add upsample block for all BUT final layer
179
+ if not is_final_block:
180
+ add_upsample = True
181
+ self.num_upsamplers += 1
182
+ else:
183
+ add_upsample = False
184
+
185
+ up_block = get_up_block(
186
+ up_block_type,
187
+ num_layers=layers_per_block + 1,
188
+ in_channels=input_channel,
189
+ out_channels=output_channel,
190
+ prev_output_channel=prev_output_channel,
191
+ temb_channels=time_embed_dim,
192
+ add_upsample=add_upsample,
193
+ resnet_eps=norm_eps,
194
+ resnet_act_fn=act_fn,
195
+ resnet_groups=norm_num_groups,
196
+ cross_attention_dim=cross_attention_dim,
197
+ attn_num_head_channels=attention_head_dim,
198
+ )
199
+ self.up_blocks.append(up_block)
200
+ prev_output_channel = output_channel
201
+
202
+ # out
203
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
204
+ self.conv_act = nn.SiLU()
205
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
206
+
207
+ def set_attention_slice(self, slice_size):
208
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
209
+ raise ValueError(
210
+ f"Make sure slice_size {slice_size} is a divisor of "
211
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
212
+ )
213
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
214
+ raise ValueError(
215
+ f"Chunk_size {slice_size} has to be smaller or equal to "
216
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
217
+ )
218
+
219
+ for block in self.down_blocks:
220
+ if hasattr(block, "attentions") and block.attentions is not None:
221
+ block.set_attention_slice(slice_size)
222
+
223
+ self.mid_block.set_attention_slice(slice_size)
224
+
225
+ for block in self.up_blocks:
226
+ if hasattr(block, "attentions") and block.attentions is not None:
227
+ block.set_attention_slice(slice_size)
228
+
229
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
230
+ for block in self.down_blocks:
231
+ if hasattr(block, "attentions") and block.attentions is not None:
232
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
233
+
234
+ self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
235
+
236
+ for block in self.up_blocks:
237
+ if hasattr(block, "attentions") and block.attentions is not None:
238
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
239
+
240
+ def _set_gradient_checkpointing(self, module, value=False):
241
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
242
+ module.gradient_checkpointing = value
243
+
244
+ def forward(
245
+ self,
246
+ sample: torch.FloatTensor,
247
+ timestep: Union[torch.Tensor, float, int],
248
+ encoder_hidden_states: torch.Tensor,
249
+ return_dict: bool = True,
250
+ ) -> Union[UNet2DConditionOutput, Tuple]:
251
+ r"""
252
+ Args:
253
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs_coarse tensor
254
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
255
+ encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
256
+ return_dict (`bool`, *optional*, defaults to `True`):
257
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
258
+
259
+ Returns:
260
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
261
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
262
+ returning a tuple, the first element is the sample tensor.
263
+ """
264
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
265
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
266
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
267
+ # on the fly if necessary.
268
+ default_overall_up_factor = 2**self.num_upsamplers
269
+
270
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
271
+ forward_upsample_size = False
272
+ upsample_size = None
273
+
274
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
275
+ logger.info("Forward upsample size to force interpolation output size.")
276
+ forward_upsample_size = True
277
+
278
+ # 0. center input if necessary
279
+ if self.config.center_input_sample:
280
+ sample = 2 * sample - 1.0
281
+
282
+ # 1. time
283
+ timesteps = timestep
284
+ if not torch.is_tensor(timesteps):
285
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
286
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
287
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
288
+ timesteps = timesteps[None].to(sample.device)
289
+
290
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
291
+ timesteps = timesteps.expand(sample.shape[0])
292
+
293
+ t_emb = self.time_proj(timesteps)
294
+
295
+ # timesteps does not contain any weights and will always return f32 tensors
296
+ # but time_embedding might actually be running in fp16. so we need to cast here.
297
+ # there might be better ways to encapsulate this.
298
+ t_emb = t_emb.to(dtype=self.dtype)
299
+ emb = self.time_embedding(t_emb)
300
+ # 2. pre-process
301
+ sample = self.conv_in(sample)
302
+ # 3. down
303
+ attn_down = []
304
+ down_block_res_samples = (sample,)
305
+ for block_idx, downsample_block in enumerate(self.down_blocks):
306
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
307
+ sample, res_samples, cross_atten_prob = downsample_block(
308
+ hidden_states=sample,
309
+ temb=emb,
310
+ encoder_hidden_states=encoder_hidden_states
311
+ )
312
+ attn_down.append(cross_atten_prob)
313
+ else:
314
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
315
+
316
+ down_block_res_samples += res_samples
317
+
318
+ # 4. mid
319
+ sample, attn_mid = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
320
+
321
+ # 5. up
322
+ attn_up = []
323
+ for i, upsample_block in enumerate(self.up_blocks):
324
+ is_final_block = i == len(self.up_blocks) - 1
325
+
326
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
327
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
328
+
329
+ # if we have not reached the final block and need to forward the
330
+ # upsample size, we do it here
331
+ if not is_final_block and forward_upsample_size:
332
+ upsample_size = down_block_res_samples[-1].shape[2:]
333
+
334
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
335
+ sample, cross_atten_prob = upsample_block(
336
+ hidden_states=sample,
337
+ temb=emb,
338
+ res_hidden_states_tuple=res_samples,
339
+ encoder_hidden_states=encoder_hidden_states,
340
+ upsample_size=upsample_size,
341
+ )
342
+ attn_up.append(cross_atten_prob)
343
+ else:
344
+ sample = upsample_block(
345
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
346
+ )
347
+ # 6. post-process
348
+ sample = self.conv_norm_out(sample)
349
+ sample = self.conv_act(sample)
350
+ sample = self.conv_out(sample)
351
+
352
+ if not return_dict:
353
+ return (sample,)
354
+
355
+ return UNet2DConditionOutput(sample=sample), attn_up, attn_mid, attn_down
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1
3
+ omegaconf==2.2.3
4
+ opencv-python
5
+ imageio==2.9.0
6
+ transformers==4.24.0
7
+ diffusers==0.7.2
8
+ accelerate==0.13.2
9
+ scipy==1.9.1
10
+ # git+https://github.com/openai/CLIP.git
11
+ hydra-core==1.2.0
12
+ tqdm
13
+ gradio==3.23.0
utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
4
+ loss = 0
5
+ object_number = len(bboxes)
6
+ if object_number == 0:
7
+ return torch.tensor(0).float().cuda()
8
+ for attn_map_integrated in attn_maps_mid:
9
+ attn_map = attn_map_integrated.chunk(2)[1]
10
+
11
+ #
12
+ b, i, j = attn_map.shape
13
+ H = W = int(math.sqrt(i))
14
+ for obj_idx in range(object_number):
15
+ obj_loss = 0
16
+ mask = torch.zeros(size=(H, W)).cuda()
17
+ for obj_box in bboxes[obj_idx]:
18
+
19
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
20
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
21
+ mask[y_min: y_max, x_min: x_max] = 1
22
+
23
+ for obj_position in object_positions[obj_idx]:
24
+ ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
25
+
26
+ activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1)
27
+
28
+ obj_loss += torch.mean((1 - activation_value) ** 2)
29
+ loss += (obj_loss/len(object_positions[obj_idx]))
30
+
31
+ # compute loss on padding tokens
32
+ # activation_value = torch.zeros(size=(b, )).cuda()
33
+ # for obj_idx in range(object_number):
34
+ # bbox = bboxes[obj_idx]
35
+ # ca_map_obj = attn_map[:, :, padding_start:].reshape(b, H, W, -1)
36
+ # activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
37
+ # int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
38
+ #
39
+ # loss += torch.mean((1 - activation_value) ** 2)
40
+
41
+
42
+ for attn_map_integrated in attn_maps_up[0]:
43
+ attn_map = attn_map_integrated.chunk(2)[1]
44
+ #
45
+ b, i, j = attn_map.shape
46
+ H = W = int(math.sqrt(i))
47
+
48
+ for obj_idx in range(object_number):
49
+ obj_loss = 0
50
+ mask = torch.zeros(size=(H, W)).cuda()
51
+ for obj_box in bboxes[obj_idx]:
52
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
53
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
54
+ mask[y_min: y_max, x_min: x_max] = 1
55
+
56
+ for obj_position in object_positions[obj_idx]:
57
+ ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
58
+ # ca_map_obj = attn_map[:, :, object_positions[obj_position]].reshape(b, H, W)
59
+
60
+ activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(
61
+ dim=-1)
62
+
63
+ obj_loss += torch.mean((1 - activation_value) ** 2)
64
+ loss += (obj_loss / len(object_positions[obj_idx]))
65
+
66
+ # compute loss on padding tokens
67
+ # activation_value = torch.zeros(size=(b, )).cuda()
68
+ # for obj_idx in range(object_number):
69
+ # bbox = bboxes[obj_idx]
70
+ # ca_map_obj = attn_map[:, :,padding_start:].reshape(b, H, W, -1)
71
+ # activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
72
+ # int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
73
+ #
74
+ # loss += torch.mean((1 - activation_value) ** 2)
75
+ loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
76
+ return loss