ZehanWang committed on
Commit ebca029
1 Parent(s): 063e9a7

Upload folder using huggingface_hub

marigold/__init__.py ADDED
@@ -0,0 +1,21 @@
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+ # More information about the method can be found at https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+
+ from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput  # noqa: F401
+ from .duplicate_unet import DoubleUNet2DConditionModel
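For orientation, a minimal usage sketch of the exports above (illustrative only: the checkpoint path is hypothetical, and it assumes MarigoldPipeline follows the standard diffusers DiffusionPipeline.from_pretrained loading convention):

    from marigold import MarigoldPipeline
    from marigold.duplicate_unet import DoubleUNet2DConditionModel

    # Hypothetical local checkpoint directory holding the pipeline components.
    pipe = MarigoldPipeline.from_pretrained("./marigold-checkpoint")
    # The two-branch UNet defined in duplicate_unet.py is exposed as a regular module.
    print(isinstance(pipe.unet, DoubleUNet2DConditionModel))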
marigold/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (294 Bytes).

marigold/__pycache__/duplicate_unet.cpython-310.pyc ADDED
Binary file (36.4 kB).

marigold/__pycache__/marigold_inpaint_pipeline.cpython-310.pyc ADDED
Binary file (24.5 kB).

marigold/__pycache__/marigold_pipeline.cpython-310.pyc ADDED
Binary file (28 kB).

marigold/__pycache__/marigold_xl_pipeline.cpython-310.pyc ADDED
Binary file (26.5 kB).
 
marigold/duplicate_unet.py ADDED
@@ -0,0 +1,1193 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import pdb
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint
+ import copy
+
+ import peft
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import UNet2DConditionLoadersMixin
+ from diffusers.utils import BaseOutput, logging
+ from diffusers.models.activations import get_activation
+ from diffusers.models.attention_processor import (
+     ADDED_KV_ATTENTION_PROCESSORS,
+     CROSS_ATTENTION_PROCESSORS,
+     AttentionProcessor,
+     AttnAddedKVProcessor,
+     AttnProcessor,
+ )
+ from diffusers.models.embeddings import (
+     GaussianFourierProjection,
+     ImageHintTimeEmbedding,
+     ImageProjection,
+     ImageTimeEmbedding,
+     TextImageProjection,
+     TextImageTimeEmbedding,
+     TextTimeEmbedding,
+     TimestepEmbedding,
+     Timesteps,
+ )
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.unet_2d_blocks import (
+     UNetMidBlock2DCrossAttn,
+     UNetMidBlock2DSimpleCrossAttn,
+     get_down_block,
+     get_up_block,
+ )
+ from diffusers.models import UNet2DConditionModel
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class UNet2DConditionOutput(BaseOutput):
+     """
+     The output of [`UNet2DConditionModel`].
+
+     Args:
+         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+     """
+
+     sample: torch.FloatTensor = None
+
+
+ class DoubleUNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+     r"""
+     A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+     shaped output.
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+     for all models (such as downloading or saving).
+
+     Parameters:
+         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+             Height and width of input/output sample.
+         in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+         out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+         flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+             Whether to flip the sin to cos in the time embedding.
+         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+             The tuple of downsample blocks to use.
+         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+             Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+             `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+             The tuple of upsample blocks to use.
+         only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+             Whether to include self-attention in the basic transformer blocks, see
+             [`~models.attention.BasicTransformerBlock`].
+         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+             The tuple of output channels for each block.
+         layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+         downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+             If `None`, normalization and activation layers are skipped in post-processing.
+         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+             The dimension of the cross attention features.
+         transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+             The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+             [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+             [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+         encoder_hid_dim (`int`, *optional*, defaults to None):
+             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+             dimension to `cross_attention_dim`.
+         encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+             If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+             embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+         num_attention_heads (`int`, *optional*):
+             The number of attention heads. If not defined, defaults to `attention_head_dim`.
+         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+             for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+         class_embed_type (`str`, *optional*, defaults to `None`):
+             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+             `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+         addition_embed_type (`str`, *optional*, defaults to `None`):
+             Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+             "text". "text" will use the `TextTimeEmbedding` layer.
+         addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+             Dimension for the timestep embeddings.
+         num_class_embeds (`int`, *optional*, defaults to `None`):
+             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+             class conditioning with `class_embed_type` equal to `None`.
+         time_embedding_type (`str`, *optional*, defaults to `positional`):
+             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+         time_embedding_dim (`int`, *optional*, defaults to `None`):
+             An optional override for the dimension of the projected time embedding.
+         time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+             Optional activation function to use only once on the time embeddings before they are passed to the rest of
+             the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+         timestep_post_act (`str`, *optional*, defaults to `None`):
+             The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+         time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+             The dimension of `cond_proj` layer in the timestep embedding.
+         conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
+         conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
+         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+             `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+             embeddings with the class embeddings.
+         mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+             Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+             `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+             `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+             otherwise.
+     """
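+     # Architecture note: this model keeps two parallel UNet branches. The RGB branch uses the
+     # standard attribute names (conv_in, down_blocks, mid_block, up_blocks, conv_out), and
+     # `duplicate_model()` deep-copies them into `depth_*` counterparts for the depth branch.
+     # The two branches exchange features through zero-initialized 1x1 convolutions
+     # (`down_rgb2depth`, `down_depth2rgb`, `mid_rgb2depth`, `mid_depth2rgb`), ControlNet-style,
+     # scaled by `rgb2depth_scale` / `depth2rgb_scale` in `forward()`.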
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         sample_size: Optional[int] = None,
+         in_channels: int = 4,
+         out_channels: int = 4,
+         center_input_sample: bool = False,
+         flip_sin_to_cos: bool = True,
+         freq_shift: int = 0,
+         down_block_types: Tuple[str] = (
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+             "DownBlock2D",
+         ),
+         mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+         up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+         only_cross_attention: Union[bool, Tuple[bool]] = False,
+         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+         layers_per_block: Union[int, Tuple[int]] = 2,
+         downsample_padding: int = 1,
+         mid_block_scale_factor: float = 1,
+         dropout: float = 0.0,
+         act_fn: str = "silu",
+         norm_num_groups: Optional[int] = 32,
+         norm_eps: float = 1e-5,
+         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+         encoder_hid_dim: Optional[int] = None,
+         encoder_hid_dim_type: Optional[str] = None,
+         attention_head_dim: Union[int, Tuple[int]] = 8,
+         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+         dual_cross_attention: bool = False,
+         use_linear_projection: bool = False,
+         class_embed_type: Optional[str] = None,
+         addition_embed_type: Optional[str] = None,
+         addition_time_embed_dim: Optional[int] = None,
+         num_class_embeds: Optional[int] = None,
+         upcast_attention: bool = False,
+         resnet_time_scale_shift: str = "default",
+         resnet_skip_time_act: bool = False,
+         resnet_out_scale_factor: int = 1.0,
+         time_embedding_type: str = "positional",
+         time_embedding_dim: Optional[int] = None,
+         time_embedding_act_fn: Optional[str] = None,
+         timestep_post_act: Optional[str] = None,
+         time_cond_proj_dim: Optional[int] = None,
+         conv_in_kernel: int = 3,
+         conv_out_kernel: int = 3,
+         projection_class_embeddings_input_dim: Optional[int] = None,
+         attention_type: str = "default",
+         class_embeddings_concat: bool = False,
+         mid_block_only_cross_attention: Optional[bool] = None,
+         cross_attention_norm: Optional[str] = None,
+         addition_embed_type_num_heads=64,
+     ):
+         super().__init__()
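+         # The constructor below mirrors diffusers' UNet2DConditionModel.__init__ for the RGB
+         # branch, and additionally registers the zero-initialized 1x1 cross-branch convolutions
+         # (`down_rgb2depth` / `down_depth2rgb` and their mid-block counterparts) that couple the
+         # RGB and depth streams.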
219
+
220
+ self.sample_size = sample_size
221
+
222
+ if num_attention_heads is not None:
223
+ raise ValueError(
224
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
225
+ )
226
+
227
+ # If `num_attention_heads` is not defined (which is the case for most models)
228
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
229
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
230
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
231
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
232
+ # which is why we correct for the naming here.
233
+ num_attention_heads = num_attention_heads or attention_head_dim
234
+
235
+ # Check inputs
236
+ if len(down_block_types) != len(up_block_types):
237
+ raise ValueError(
238
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
239
+ )
240
+
241
+ if len(block_out_channels) != len(down_block_types):
242
+ raise ValueError(
243
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
244
+ )
245
+
246
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
264
+ )
265
+
266
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
267
+ raise ValueError(
268
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ # input
272
+ conv_in_padding = (conv_in_kernel - 1) // 2
273
+ self.conv_in = nn.Conv2d(
274
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
275
+ )
276
+
277
+ # time
278
+ if time_embedding_type == "fourier":
279
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
280
+ if time_embed_dim % 2 != 0:
281
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
282
+ self.time_proj = GaussianFourierProjection(
283
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
284
+ )
285
+ timestep_input_dim = time_embed_dim
286
+ elif time_embedding_type == "positional":
287
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
288
+
289
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
290
+ timestep_input_dim = block_out_channels[0]
291
+ else:
292
+ raise ValueError(
293
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
294
+ )
295
+
296
+ self.time_embedding = TimestepEmbedding(
297
+ timestep_input_dim,
298
+ time_embed_dim,
299
+ act_fn=act_fn,
300
+ post_act_fn=timestep_post_act,
301
+ cond_proj_dim=time_cond_proj_dim,
302
+ )
303
+
304
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
305
+ encoder_hid_dim_type = "text_proj"
306
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
307
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
308
+
309
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
310
+ raise ValueError(
311
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
312
+ )
313
+
314
+ if encoder_hid_dim_type == "text_proj":
315
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
316
+ elif encoder_hid_dim_type == "text_image_proj":
317
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
318
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
319
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
320
+ self.encoder_hid_proj = TextImageProjection(
321
+ text_embed_dim=encoder_hid_dim,
322
+ image_embed_dim=cross_attention_dim,
323
+ cross_attention_dim=cross_attention_dim,
324
+ )
325
+ elif encoder_hid_dim_type == "image_proj":
326
+ # Kandinsky 2.2
327
+ self.encoder_hid_proj = ImageProjection(
328
+ image_embed_dim=encoder_hid_dim,
329
+ cross_attention_dim=cross_attention_dim,
330
+ )
331
+ elif encoder_hid_dim_type is not None:
332
+ raise ValueError(
333
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
334
+ )
335
+ else:
336
+ self.encoder_hid_proj = None
337
+
338
+ # class embedding
339
+ if class_embed_type is None and num_class_embeds is not None:
340
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
341
+ elif class_embed_type == "timestep":
342
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
343
+ elif class_embed_type == "identity":
344
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
345
+ elif class_embed_type == "projection":
346
+ if projection_class_embeddings_input_dim is None:
347
+ raise ValueError(
348
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
349
+ )
350
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
351
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
352
+ # 2. it projects from an arbitrary input dimension.
353
+ #
354
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
355
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
356
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
357
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
358
+ elif class_embed_type == "simple_projection":
359
+ if projection_class_embeddings_input_dim is None:
360
+ raise ValueError(
361
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
362
+ )
363
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
364
+ else:
365
+ self.class_embedding = None
366
+
367
+ if addition_embed_type == "text":
368
+ if encoder_hid_dim is not None:
369
+ text_time_embedding_from_dim = encoder_hid_dim
370
+ else:
371
+ text_time_embedding_from_dim = cross_attention_dim
372
+
373
+ self.add_embedding = TextTimeEmbedding(
374
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
375
+ )
376
+ elif addition_embed_type == "text_image":
377
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
378
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
379
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
380
+ self.add_embedding = TextImageTimeEmbedding(
381
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
382
+ )
383
+ elif addition_embed_type == "text_time":
384
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
385
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
386
+ elif addition_embed_type == "image":
387
+ # Kandinsky 2.2
388
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
389
+ elif addition_embed_type == "image_hint":
390
+ # Kandinsky 2.2 ControlNet
391
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
392
+ elif addition_embed_type is not None:
393
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
394
+
395
+ if time_embedding_act_fn is None:
396
+ self.time_embed_act = None
397
+ else:
398
+ self.time_embed_act = get_activation(time_embedding_act_fn)
399
+
400
+ self.down_blocks = nn.ModuleList([])
401
+ self.up_blocks = nn.ModuleList([])
402
+
403
+ if isinstance(only_cross_attention, bool):
404
+ if mid_block_only_cross_attention is None:
405
+ mid_block_only_cross_attention = only_cross_attention
406
+
407
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
408
+
409
+ if mid_block_only_cross_attention is None:
410
+ mid_block_only_cross_attention = False
411
+
412
+ if isinstance(num_attention_heads, int):
413
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
414
+
415
+ if isinstance(attention_head_dim, int):
416
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
417
+
418
+ if isinstance(cross_attention_dim, int):
419
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
420
+
421
+ if isinstance(layers_per_block, int):
422
+ layers_per_block = [layers_per_block] * len(down_block_types)
423
+
424
+ if isinstance(transformer_layers_per_block, int):
425
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
426
+
427
+ if class_embeddings_concat:
428
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
429
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
430
+ # regular time embeddings
431
+ blocks_time_embed_dim = time_embed_dim * 2
432
+ else:
433
+ blocks_time_embed_dim = time_embed_dim
434
+
435
+ # interact layer:
436
+ self.down_rgb2depth = nn.ModuleList([])
437
+ self.down_depth2rgb = nn.ModuleList([])
438
+
439
+ # down
440
+ output_channel = block_out_channels[0]
441
+ for i, down_block_type in enumerate(down_block_types):
442
+ input_channel = output_channel
443
+ output_channel = block_out_channels[i]
444
+ is_final_block = i == len(block_out_channels) - 1
445
+
446
+ down_block = get_down_block(
447
+ down_block_type,
448
+ num_layers=layers_per_block[i],
449
+ transformer_layers_per_block=transformer_layers_per_block[i],
450
+ in_channels=input_channel,
451
+ out_channels=output_channel,
452
+ temb_channels=blocks_time_embed_dim,
453
+ add_downsample=not is_final_block,
454
+ resnet_eps=norm_eps,
455
+ resnet_act_fn=act_fn,
456
+ resnet_groups=norm_num_groups,
457
+ cross_attention_dim=cross_attention_dim[i],
458
+ num_attention_heads=num_attention_heads[i],
459
+ downsample_padding=downsample_padding,
460
+ dual_cross_attention=dual_cross_attention,
461
+ use_linear_projection=use_linear_projection,
462
+ only_cross_attention=only_cross_attention[i],
463
+ upcast_attention=upcast_attention,
464
+ resnet_time_scale_shift=resnet_time_scale_shift,
465
+ attention_type=attention_type,
466
+ resnet_skip_time_act=resnet_skip_time_act,
467
+ resnet_out_scale_factor=resnet_out_scale_factor,
468
+ cross_attention_norm=cross_attention_norm,
469
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
470
+ dropout=dropout,
471
+ )
472
+ self.down_blocks.append(down_block)
473
+
474
+ rgb2depth_block = nn.Conv2d(input_channel, input_channel, kernel_size=1)
475
+ rgb2depth_block = self.zero_module(rgb2depth_block)
476
+ self.down_rgb2depth.append(rgb2depth_block)
477
+ depth2rgb_block = nn.Conv2d(input_channel, input_channel, kernel_size=1)
478
+ depth2rgb_block = self.zero_module(depth2rgb_block)
479
+ self.down_depth2rgb.append(depth2rgb_block)
480
+
481
+ for _ in range(layers_per_block[i]):
482
+ rgb2depth_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
483
+ rgb2depth_block = self.zero_module(rgb2depth_block)
484
+ self.down_rgb2depth.append(rgb2depth_block)
485
+ depth2rgb_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
486
+ depth2rgb_block = self.zero_module(depth2rgb_block)
487
+ self.down_depth2rgb.append(depth2rgb_block)
488
+ #
489
+ # if not is_final_block:
490
+ # rgb2depth_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
491
+ # rgb2depth_block = self.zero_module(rgb2depth_block)
492
+ # self.down_rgb2depth.append(rgb2depth_block)
493
+ # depth2rgb_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
494
+ # depth2rgb_block = self.zero_module(depth2rgb_block)
495
+ # self.down_depth2rgb.append(depth2rgb_block)
496
+
497
+
498
+ mid_block_channel = block_out_channels[-1]
499
+
500
+ rgb2depth_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
501
+ rgb2depth_block = self.zero_module(rgb2depth_block)
502
+ self.mid_rgb2depth = rgb2depth_block
503
+
504
+ depth2rgb_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
505
+ depth2rgb_block = self.zero_module(depth2rgb_block)
506
+ self.mid_depth2rgb = depth2rgb_block
507
+
508
+ # mid
509
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
510
+ self.mid_block = UNetMidBlock2DCrossAttn(
511
+ transformer_layers_per_block=transformer_layers_per_block[-1],
512
+ in_channels=block_out_channels[-1],
513
+ temb_channels=blocks_time_embed_dim,
514
+ dropout=dropout,
515
+ resnet_eps=norm_eps,
516
+ resnet_act_fn=act_fn,
517
+ output_scale_factor=mid_block_scale_factor,
518
+ resnet_time_scale_shift=resnet_time_scale_shift,
519
+ cross_attention_dim=cross_attention_dim[-1],
520
+ num_attention_heads=num_attention_heads[-1],
521
+ resnet_groups=norm_num_groups,
522
+ dual_cross_attention=dual_cross_attention,
523
+ use_linear_projection=use_linear_projection,
524
+ upcast_attention=upcast_attention,
525
+ attention_type=attention_type,
526
+ )
527
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
528
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
529
+ in_channels=block_out_channels[-1],
530
+ temb_channels=blocks_time_embed_dim,
531
+ dropout=dropout,
532
+ resnet_eps=norm_eps,
533
+ resnet_act_fn=act_fn,
534
+ output_scale_factor=mid_block_scale_factor,
535
+ cross_attention_dim=cross_attention_dim[-1],
536
+ attention_head_dim=attention_head_dim[-1],
537
+ resnet_groups=norm_num_groups,
538
+ resnet_time_scale_shift=resnet_time_scale_shift,
539
+ skip_time_act=resnet_skip_time_act,
540
+ only_cross_attention=mid_block_only_cross_attention,
541
+ cross_attention_norm=cross_attention_norm,
542
+ )
543
+ elif mid_block_type is None:
544
+ self.mid_block = None
545
+ else:
546
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
547
+
548
+ # count how many layers upsample the images
549
+ self.num_upsamplers = 0
550
+
551
+ # up
552
+ reversed_block_out_channels = list(reversed(block_out_channels))
553
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
554
+ reversed_layers_per_block = list(reversed(layers_per_block))
555
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
556
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
557
+ only_cross_attention = list(reversed(only_cross_attention))
558
+
559
+ output_channel = reversed_block_out_channels[0]
560
+
561
+ for i, up_block_type in enumerate(up_block_types):
562
+ is_final_block = i == len(block_out_channels) - 1
563
+
564
+ prev_output_channel = output_channel
565
+ output_channel = reversed_block_out_channels[i]
566
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
567
+
568
+ # add upsample block for all BUT final layer
569
+ if not is_final_block:
570
+ add_upsample = True
571
+ self.num_upsamplers += 1
572
+ else:
573
+ add_upsample = False
574
+
575
+ up_block = get_up_block(
576
+ up_block_type,
577
+ num_layers=reversed_layers_per_block[i] + 1,
578
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
579
+ in_channels=input_channel,
580
+ out_channels=output_channel,
581
+ prev_output_channel=prev_output_channel,
582
+ temb_channels=blocks_time_embed_dim,
583
+ add_upsample=add_upsample,
584
+ resnet_eps=norm_eps,
585
+ resnet_act_fn=act_fn,
586
+ resolution_idx=i,
587
+ resnet_groups=norm_num_groups,
588
+ cross_attention_dim=reversed_cross_attention_dim[i],
589
+ num_attention_heads=reversed_num_attention_heads[i],
590
+ dual_cross_attention=dual_cross_attention,
591
+ use_linear_projection=use_linear_projection,
592
+ only_cross_attention=only_cross_attention[i],
593
+ upcast_attention=upcast_attention,
594
+ resnet_time_scale_shift=resnet_time_scale_shift,
595
+ attention_type=attention_type,
596
+ resnet_skip_time_act=resnet_skip_time_act,
597
+ resnet_out_scale_factor=resnet_out_scale_factor,
598
+ cross_attention_norm=cross_attention_norm,
599
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
600
+ dropout=dropout,
601
+ )
602
+ self.up_blocks.append(up_block)
603
+ prev_output_channel = output_channel
604
+
605
+ # out
606
+ if norm_num_groups is not None:
607
+ self.conv_norm_out = nn.GroupNorm(
608
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
609
+ )
610
+
611
+ self.conv_act = get_activation(act_fn)
612
+
613
+ else:
614
+ self.conv_norm_out = None
615
+ self.conv_act = None
616
+
617
+ conv_out_padding = (conv_out_kernel - 1) // 2
618
+ self.conv_out = nn.Conv2d(
619
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
620
+ )
621
+ self.separate_list = None
622
+ self.rgb_conv_in_double = 0
623
+ self.depth_conv_in_double = 0
624
+
625
+ def inpaint_rgb_conv_in(self): # replace the first layer to accept 13 in_channels
626
+ _n_convin_out_channel = self.conv_in.out_channels
627
+ _new_conv_in = nn.Conv2d(13, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
628
+ self.conv_in = _new_conv_in
629
+
630
+ print("Unet rgb conv_in layer is replaced by 13 conv_in channel")
631
+ return
632
+
633
+ def inpaint_depth_conv_in(self): # replace the first layer to accept 13 in_channels
634
+ _n_convin_out_channel = self.depth_conv_in.out_channels
635
+ _new_conv_in = nn.Conv2d(13, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
636
+ self.depth_conv_in = _new_conv_in
637
+ print("Unet depth conv_in layer is replaced by 13 conv_in channel")
638
+ return
639
+
640
+ def duplicate_model(self):
641
+ self.depth_time_embedding = copy.deepcopy(self.time_embedding)
642
+ self.depth_time_proj = copy.deepcopy(self.time_proj)
643
+
644
+ self.depth_conv_in = copy.deepcopy(self.conv_in)
645
+ self.depth_conv_norm_out = copy.deepcopy(self.conv_norm_out)
646
+ self.depth_conv_act = copy.deepcopy(self.conv_act)
647
+ self.depth_conv_out = copy.deepcopy(self.conv_out)
648
+
649
+ self.depth_down_blocks = nn.ModuleList()
650
+ self.depth_up_blocks = nn.ModuleList()
651
+
652
+ for i in range(len(self.down_blocks)):
653
+ self.depth_down_blocks.append(copy.deepcopy(self.down_blocks[i]))
654
+ for i in range(len(self.up_blocks)):
655
+ self.depth_up_blocks.append(copy.deepcopy(self.up_blocks[i]))
656
+
657
+ self.depth_mid_block = copy.deepcopy(self.mid_block)
658
+
659
+ return
660
+
661
+ def zero_module(self, module):
662
+ for p in module.parameters():
663
+ nn.init.zeros_(p)
664
+ return module
665
+
666
+ @property
667
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
668
+ r"""
669
+ Returns:
670
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
671
+ indexed by its weight name.
672
+ """
673
+ # set recursively
674
+ processors = {}
675
+
676
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
677
+ if hasattr(module, "get_processor"):
678
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
679
+
680
+ for sub_name, child in module.named_children():
681
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
682
+
683
+ return processors
684
+
685
+ for name, module in self.named_children():
686
+ fn_recursive_add_processors(name, module, processors)
687
+
688
+ return processors
689
+
690
+ def set_attn_processor(
691
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
692
+ ):
693
+ r"""
694
+ Sets the attention processor to use to compute attention.
695
+
696
+ Parameters:
697
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
698
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
699
+ for **all** `Attention` layers.
700
+
701
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
702
+ processor. This is strongly recommended when setting trainable attention processors.
703
+
704
+ """
705
+ count = len(self.attn_processors.keys())
706
+
707
+ if isinstance(processor, dict) and len(processor) != count:
708
+ raise ValueError(
709
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
710
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
711
+ )
712
+
713
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
714
+ if hasattr(module, "set_processor"):
715
+ if not isinstance(processor, dict):
716
+ module.set_processor(processor, _remove_lora=_remove_lora)
717
+ else:
718
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
719
+
720
+ for sub_name, child in module.named_children():
721
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
722
+
723
+ for name, module in self.named_children():
724
+ fn_recursive_attn_processor(name, module, processor)
725
+
726
+ def set_default_attn_processor(self):
727
+ """
728
+ Disables custom attention processors and sets the default attention implementation.
729
+ """
730
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
731
+ processor = AttnAddedKVProcessor()
732
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
733
+ processor = AttnProcessor()
734
+ else:
735
+ raise ValueError(
736
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
737
+ )
738
+
739
+ self.set_attn_processor(processor, _remove_lora=True)
740
+
741
+ def set_attention_slice(self, slice_size):
742
+ r"""
743
+ Enable sliced attention computation.
744
+
745
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
746
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
747
+
748
+ Args:
749
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
750
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
751
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
752
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
753
+ must be a multiple of `slice_size`.
754
+ """
755
+ sliceable_head_dims = []
756
+
757
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
758
+ if hasattr(module, "set_attention_slice"):
759
+ sliceable_head_dims.append(module.sliceable_head_dim)
760
+
761
+ for child in module.children():
762
+ fn_recursive_retrieve_sliceable_dims(child)
763
+
764
+ # retrieve number of attention layers
765
+ for module in self.children():
766
+ fn_recursive_retrieve_sliceable_dims(module)
767
+
768
+ num_sliceable_layers = len(sliceable_head_dims)
769
+
770
+ if slice_size == "auto":
771
+ # half the attention head size is usually a good trade-off between
772
+ # speed and memory
773
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
774
+ elif slice_size == "max":
775
+ # make smallest slice possible
776
+ slice_size = num_sliceable_layers * [1]
777
+
778
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
779
+
780
+ if len(slice_size) != len(sliceable_head_dims):
781
+ raise ValueError(
782
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
783
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
784
+ )
785
+
786
+ for i in range(len(slice_size)):
787
+ size = slice_size[i]
788
+ dim = sliceable_head_dims[i]
789
+ if size is not None and size > dim:
790
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
791
+
792
+ # Recursively walk through all the children.
793
+ # Any children which exposes the set_attention_slice method
794
+ # gets the message
795
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
796
+ if hasattr(module, "set_attention_slice"):
797
+ module.set_attention_slice(slice_size.pop())
798
+
799
+ for child in module.children():
800
+ fn_recursive_set_attention_slice(child, slice_size)
801
+
802
+ reversed_slice_size = list(reversed(slice_size))
803
+ for module in self.children():
804
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
805
+
806
+ def _set_gradient_checkpointing(self, module, value=False):
807
+ if hasattr(module, "gradient_checkpointing"):
808
+ module.gradient_checkpointing = value
809
+
810
+ def enable_freeu(self, s1, s2, b1, b2):
811
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
812
+
813
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
814
+
815
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
816
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
817
+
818
+ Args:
819
+ s1 (`float`):
820
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
821
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
822
+ s2 (`float`):
823
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
824
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
825
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
826
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
827
+ """
828
+ for i, upsample_block in enumerate(self.up_blocks):
829
+ setattr(upsample_block, "s1", s1)
830
+ setattr(upsample_block, "s2", s2)
831
+ setattr(upsample_block, "b1", b1)
832
+ setattr(upsample_block, "b2", b2)
833
+
834
+ def disable_freeu(self):
835
+ """Disables the FreeU mechanism."""
836
+ freeu_keys = {"s1", "s2", "b1", "b2"}
837
+ for i, upsample_block in enumerate(self.up_blocks):
838
+ for k in freeu_keys:
839
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
840
+ setattr(upsample_block, k, None)
841
+
842
+ def forward(
843
+ self,
844
+ sample: torch.FloatTensor,
845
+ # timestep: Union[torch.Tensor, float, int],
846
+ rgb_timestep: Union[torch.Tensor, float, int],
847
+ depth_timestep: Union[torch.Tensor, float, int],
848
+ encoder_hidden_states: torch.Tensor,
849
+ controlnet_connection: bool = True,
850
+ class_labels: Optional[torch.Tensor] = None,
851
+ timestep_cond: Optional[torch.Tensor] = None,
852
+ attention_mask: Optional[torch.Tensor] = None,
853
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
854
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
855
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
856
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
857
+ encoder_attention_mask: Optional[torch.Tensor] = None,
858
+ return_dict: bool = True,
859
+ depth2rgb_scale: float = 1.,
860
+ rgb2depth_scale: float = 1.
861
+ ) -> Union[UNet2DConditionOutput, Tuple]:
862
+ r"""
863
+ The [`UNet2DConditionModel`] forward method.
864
+
865
+ Args:
866
+ sample (`torch.FloatTensor`):
867
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
868
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
869
+ encoder_hidden_states (`torch.FloatTensor`):
870
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
871
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
872
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
873
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
874
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
875
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
876
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
877
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
878
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
879
+ negative values to the attention scores corresponding to "discard" tokens.
880
+ cross_attention_kwargs (`dict`, *optional*):
881
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
882
+ `self.processor` in
883
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
884
+ added_cond_kwargs: (`dict`, *optional*):
885
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
886
+ are passed along to the UNet blocks.
887
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
888
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
889
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
890
+ A tensor that if specified is added to the residual of the middle unet block.
891
+ encoder_attention_mask (`torch.Tensor`):
892
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
893
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
894
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
895
+ return_dict (`bool`, *optional*, defaults to `True`):
896
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
897
+ tuple.
898
+ cross_attention_kwargs (`dict`, *optional*):
899
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
900
+ added_cond_kwargs: (`dict`, *optional*):
901
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
902
+ are passed along to the UNet blocks.
903
+
904
+ Returns:
905
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
906
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
907
+ a `tuple` is returned where the first element is the sample tensor.
908
+ """
909
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
910
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
911
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
912
+ # on the fly if necessary.
913
+ default_overall_up_factor = 2**self.num_upsamplers
914
+
915
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
916
+ forward_upsample_size = False
917
+ upsample_size = None
918
+
919
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
920
+ # Forward upsample size to force interpolation output size.
921
+ forward_upsample_size = True
922
+
923
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
924
+ # expects mask of shape:
925
+ # [batch, key_tokens]
926
+ # adds singleton query_tokens dimension:
927
+ # [batch, 1, key_tokens]
928
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
929
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
930
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
931
+ if attention_mask is not None:
932
+ # assume that mask is expressed as:
933
+ # (1 = keep, 0 = discard)
934
+ # convert mask into a bias that can be added to attention scores:
935
+ # (keep = +0, discard = -10000.0)
936
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
937
+ attention_mask = attention_mask.unsqueeze(1)
938
+
939
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
940
+ if encoder_attention_mask is not None:
941
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
942
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
943
+
944
+ # 0. center input if necessary
945
+ if self.config.center_input_sample:
946
+ sample = 2 * sample - 1.0
947
+
948
+ # 1. time
949
+ # timesteps = timestep
950
+
951
+ # for timestep in [rgb_timestep, depth_timestep]:
952
+ rgb_timesteps = rgb_timestep
953
+ depth_timesteps = depth_timestep
954
+ # timesteps = timestep
955
+ if not torch.is_tensor(rgb_timesteps):
956
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
957
+ # This would be a good case for the `match` statement (Python 3.10+)
958
+ is_mps = sample.device.type == "mps"
959
+ if isinstance(rgb_timestep, float):
960
+ dtype = torch.float32 if is_mps else torch.float64
961
+ else:
962
+ dtype = torch.int32 if is_mps else torch.int64
963
+ rgb_timesteps = torch.tensor([rgb_timesteps], dtype=dtype, device=sample.device)
964
+ elif len(rgb_timesteps.shape) == 0:
965
+ rgb_timesteps = rgb_timesteps[None].to(sample.device)
966
+ rgb_timesteps = rgb_timesteps.expand(sample.shape[0])
967
+ rgb_t_emb = self.time_proj(rgb_timesteps)
968
+ rgb_t_emb = rgb_t_emb.to(dtype=sample.dtype)
969
+ rgb_emb = self.time_embedding(rgb_t_emb, timestep_cond)
970
+
971
+ if not torch.is_tensor(depth_timesteps):
972
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
973
+ # This would be a good case for the `match` statement (Python 3.10+)
974
+ is_mps = sample.device.type == "mps"
975
+ if isinstance(depth_timestep, float):
976
+ dtype = torch.float32 if is_mps else torch.float64
977
+ else:
978
+ dtype = torch.int32 if is_mps else torch.int64
979
+ depth_timesteps = torch.tensor([depth_timesteps], dtype=dtype, device=sample.device)
980
+ elif len(depth_timesteps.shape) == 0:
981
+ depth_timesteps = depth_timesteps[None].to(sample.device)
982
+ depth_timesteps = depth_timesteps.expand(sample.shape[0])
983
+ depth_t_emb = self.depth_time_proj(depth_timesteps)
984
+ depth_t_emb = depth_t_emb.to(dtype=sample.dtype)
985
+ depth_emb = self.depth_time_embedding(depth_t_emb, timestep_cond)
986
+ aug_emb = None
987
+
988
+ rgb_emb = rgb_emb + aug_emb if aug_emb is not None else rgb_emb
989
+ depth_emb = depth_emb + aug_emb if aug_emb is not None else depth_emb
990
+
991
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
992
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
993
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
994
+ # Kadinsky 2.1 - style
995
+ if "image_embeds" not in added_cond_kwargs:
996
+ raise ValueError(
997
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
998
+ )
999
+ image_embeds = added_cond_kwargs.get("image_embeds")
1000
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1001
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1002
+ # Kandinsky 2.2 - style
1003
+ if "image_embeds" not in added_cond_kwargs:
1004
+ raise ValueError(
1005
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1006
+ )
1007
+ image_embeds = added_cond_kwargs.get("image_embeds")
1008
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1009
+
1010
+ # 2. pre-process
1011
+ rgb_sample, depth_sample = sample.chunk(2, dim=1)
1012
+ depth_sample = self.depth_conv_in(depth_sample)
1013
+ sample = self.conv_in(rgb_sample)
1014
+
1015
+ # 2.5 GLIGEN position net
1016
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1017
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1018
+ gligen_args = cross_attention_kwargs.pop("gligen")
1019
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1020
+
1021
+ # 3. down
1022
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1023
+
1024
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1025
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1026
+
1027
+ down_block_res_depth_samples = (depth_sample,)
1028
+ down_block_res_samples = (sample,)
1029
+
1030
+ for block_id, downsample_block in enumerate(self.down_blocks):
1031
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1032
+ # For t2i-adapter CrossAttnDownBlock2D
1033
+ additional_residuals = {}
1034
+ sample, res_samples = downsample_block(
1035
+ hidden_states=sample,
1036
+ temb=rgb_emb,
1037
+ encoder_hidden_states=encoder_hidden_states,
1038
+ attention_mask=attention_mask,
1039
+ cross_attention_kwargs=cross_attention_kwargs,
1040
+ encoder_attention_mask=encoder_attention_mask,
1041
+ **additional_residuals,
1042
+ )
1043
+
1044
+ # depth_res_samples = res_samples
1045
+ # if separate_list is not None and block_id < separate_list[0]:
1046
+
1047
+ depth_sample, depth_res_samples = self.depth_down_blocks[block_id](
1048
+ hidden_states=depth_sample,
1049
+ temb=depth_emb,
1050
+ encoder_hidden_states=encoder_hidden_states,
1051
+ attention_mask=attention_mask,
1052
+ cross_attention_kwargs=cross_attention_kwargs,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ **additional_residuals,
1055
+ )
1056
+ # mean_sample = (depth_sample + sample) / 2
1057
+ # depth_sample, sample = mean_sample, mean_sample
1058
+ else:
1059
+ sample, res_samples = downsample_block(hidden_states=sample, temb=rgb_emb, scale=lora_scale)
1060
+
1061
+ # depth_res_samples = res_samples
1062
+ # if separate_list is not None and block_id < separate_list[0]:
1063
+ if isinstance(self.depth_down_blocks[block_id], torch.nn.Linear):
1064
+ depth_sample, depth_res_samples = downsample_block(hidden_states=depth_sample, temb=depth_emb, scale=lora_scale)
1065
+ else:
1066
+ depth_sample, depth_res_samples = self.depth_down_blocks[block_id](hidden_states=depth_sample, temb=depth_emb, scale=lora_scale)
1067
+ # mean_sample = (depth_sample + sample) / 2
1068
+ # depth_sample, sample = mean_sample, mean_sample
1069
+
1070
+ if is_adapter and len(down_block_additional_residuals) > 0:
1071
+ sample += down_block_additional_residuals.pop(0)
1072
+
1073
+ down_block_res_samples += res_samples
1074
+ down_block_res_depth_samples += depth_res_samples
1075
+
1076
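+ # optional cross-branch exchange: every skip-connection residual is augmented with a scaled projection of the other branch's residual (depth-to-RGB and RGB-to-depth)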
+ if controlnet_connection:
1077
+ new_down_block_res_samples = ()
1078
+ new_down_block_res_depth_samples = ()
1079
+ for down_block_res_sample, down_block_res_depth_sample, rgb2depth_block, depth2rgb_block in zip(
1080
+ down_block_res_samples, down_block_res_depth_samples, self.down_rgb2depth, self.down_depth2rgb
1081
+ ):
1082
+ new_down_block_res_sample = down_block_res_sample + depth2rgb_scale * depth2rgb_block(down_block_res_depth_sample)
1083
+ new_down_block_res_samples = new_down_block_res_samples + (new_down_block_res_sample,)
1084
+
1085
+ new_down_block_res_depth_sample = down_block_res_depth_sample + rgb2depth_scale * rgb2depth_block(down_block_res_sample)
1086
+ new_down_block_res_depth_samples = new_down_block_res_depth_samples + (new_down_block_res_depth_sample,)
1087
+
1088
+ down_block_res_samples = new_down_block_res_samples
1089
+ down_block_res_depth_samples = new_down_block_res_depth_samples
1090
+
1091
+ from diffusers import ControlNetModel
1092
+ # 4. mid
1093
+ if self.mid_block is not None:
1094
+ sample = self.mid_block(
1095
+ sample,
1096
+ rgb_emb,
1097
+ encoder_hidden_states=encoder_hidden_states,
1098
+ attention_mask=attention_mask,
1099
+ cross_attention_kwargs=cross_attention_kwargs,
1100
+ encoder_attention_mask=encoder_attention_mask,
1101
+ )
1102
+ # if separate_list is not None and len(separate_list[0]) == 3:
1103
+
1104
+ depth_sample = self.depth_mid_block(
1105
+ depth_sample,
1106
+ depth_emb,
1107
+ encoder_hidden_states=encoder_hidden_states,
1108
+ attention_mask=attention_mask,
1109
+ cross_attention_kwargs=cross_attention_kwargs,
1110
+ encoder_attention_mask=encoder_attention_mask,
1111
+ )
1112
+
1113
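+ # the same cross-branch exchange applied to the bottleneck features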
+ if controlnet_connection:
1114
+ new_depth_sample = depth_sample + rgb2depth_scale * self.mid_rgb2depth(sample)
1115
+ new_image_sample = sample + depth2rgb_scale * self.mid_depth2rgb(depth_sample)
1116
+ depth_sample = new_depth_sample
1117
+ sample = new_image_sample
1118
+
1119
+ # 5. up
1120
+ for i, upsample_block in enumerate(self.up_blocks):
1121
+ rever_block_id = len(self.up_blocks) - i - 1
1122
+ is_final_block = i == len(self.up_blocks) - 1
1123
+
1124
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1125
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1126
+
1127
+ res_depth_samples = down_block_res_depth_samples[-len(upsample_block.resnets):]
1128
+ down_block_res_depth_samples = down_block_res_depth_samples[: -len(upsample_block.resnets)]
1129
+
1130
+ # if we have not reached the final block and need to forward the
1131
+ # upsample size, we do it here
1132
+ # if separate_list is not None and rever_block_id < separate_list[-1]:
1133
+
1134
+ if not is_final_block and forward_upsample_size:
1135
+ upsample_size = down_block_res_samples[-1].shape[2:]
1136
+
1137
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1138
+ sample = upsample_block(
1139
+ hidden_states=sample,
1140
+ temb=rgb_emb,
1141
+ res_hidden_states_tuple=res_samples,
1142
+ encoder_hidden_states=encoder_hidden_states,
1143
+ cross_attention_kwargs=cross_attention_kwargs,
1144
+ upsample_size=upsample_size,
1145
+ attention_mask=attention_mask,
1146
+ encoder_attention_mask=encoder_attention_mask,
1147
+ )
1148
+
1149
+ depth_sample = self.depth_up_blocks[i](
1150
+ hidden_states=depth_sample,
1151
+ temb=depth_emb,
1152
+ res_hidden_states_tuple=res_depth_samples,
1153
+ encoder_hidden_states=encoder_hidden_states,
1154
+ cross_attention_kwargs=cross_attention_kwargs,
1155
+ upsample_size=upsample_size,
1156
+ attention_mask=attention_mask,
1157
+ encoder_attention_mask=encoder_attention_mask,
1158
+ )
1159
+ # mean_sample = (depth_sample + sample) / 2
1160
+ # depth_sample, sample = mean_sample, mean_sample
1161
+ else:
1162
+ sample = upsample_block(
1163
+ hidden_states=sample,
1164
+ temb=rgb_emb,
1165
+ res_hidden_states_tuple=res_samples,
1166
+ upsample_size=upsample_size,
1167
+ scale=lora_scale,
1168
+ )
1169
+ # if separate_list is not None and rever_block_id < separate_list[-1]:
1170
+ depth_sample = self.depth_up_blocks[i](
1171
+ hidden_states=depth_sample,
1172
+ temb=depth_emb,
1173
+ res_hidden_states_tuple=res_depth_samples,
1174
+ upsample_size=upsample_size,
1175
+ scale=lora_scale,
1176
+ )
1177
+
1178
+ # 6. post-process
1179
+ if self.conv_norm_out:
1180
+ sample = self.conv_norm_out(sample)
1181
+ sample = self.conv_act(sample)
1182
+ sample = self.conv_out(sample)
1183
+
1184
+ if self.conv_norm_out:
1185
+ depth_sample = self.depth_conv_norm_out(depth_sample)
1186
+ depth_sample = self.depth_conv_act(depth_sample)
1187
+ depth_sample = self.depth_conv_out(depth_sample)
1188
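+ # stack the RGB and depth predictions back along the channel dimension into a single output tensor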
+ sample = torch.cat([sample, depth_sample], dim=1)
1189
+
1190
+ if not return_dict:
1191
+ return (sample,)
1192
+
1193
+ return UNet2DConditionOutput(sample=sample)
marigold/marigold_inpaint_pipeline.py ADDED
@@ -0,0 +1,873 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+ import logging
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ import pdb
24
+ from typing import Dict, Optional, Union
25
+ import PIL.Image
26
+ import numpy as np
27
+ import torch
28
+ from diffusers import (
29
+ AutoencoderKL,
30
+ DDIMScheduler,
31
+ DiffusionPipeline,
32
+ LCMScheduler,
33
+ PNDMScheduler,
34
+ UNet2DConditionModel,
35
+ )
36
+ from .duplicate_unet import DoubleUNet2DConditionModel
37
+ from torch.nn import Conv2d
38
+ from PIL import ImageDraw, ImageFont
39
+ from torch.nn.parameter import Parameter
40
+ from diffusers.utils import BaseOutput, make_image_grid
41
+ from PIL import Image
42
+ from torch.utils.data import DataLoader, TensorDataset
43
+ from torchvision.transforms import InterpolationMode
44
+ from torchvision.transforms.functional import pil_to_tensor, resize
45
+ from tqdm.auto import tqdm
46
+ from transformers import CLIPTextModel, CLIPTokenizer
47
+
48
+ from .util.batchsize import find_batch_size
49
+ from .util.ensemble import ensemble_depth
50
+ from .util.image_util import (
51
+ chw2hwc,
52
+ colorize_depth_maps,
53
+ get_tv_resample_method,
54
+ resize_max_res,
55
+ )
56
+
57
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
58
+ """
59
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
60
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
61
+ """
62
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
63
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
64
+ # rescale the results from guidance (fixes overexposure)
65
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
66
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
67
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
68
+ return noise_cfg
69
+
70
+ class MarigoldDepthOutput(BaseOutput):
71
+ """
72
+ Output class for Marigold monocular depth prediction pipeline.
73
+
74
+ Args:
75
+ depth_np (`np.ndarray`):
76
+ Predicted depth map, with depth values in the range of [0, 1].
77
+ depth_colored (`PIL.Image.Image`):
78
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
79
+ uncertainty (`None` or `np.ndarray`):
80
+ Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
81
+ """
82
+
83
+ depth_np: np.ndarray
84
+ depth_colored: Union[None, Image.Image]
85
+ uncertainty: Union[None, np.ndarray]
86
+
87
+ class MarigoldInpaintPipeline(DiffusionPipeline):
88
+ """
89
+ Pipeline for joint RGB and depth inpainting built on Marigold: https://marigoldmonodepth.github.io.
90
+
91
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
92
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
93
+
94
+ Args:
95
+ unet (`UNet2DConditionModel`):
96
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
97
+ vae (`AutoencoderKL`):
98
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
99
+ to and from latent representations.
100
+ scheduler (`DDIMScheduler`):
101
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
102
+ text_encoder (`CLIPTextModel`):
103
+ Text-encoder, for empty text embedding.
104
+ tokenizer (`CLIPTokenizer`):
105
+ CLIP tokenizer.
106
+ scale_invariant (`bool`, *optional*):
107
+ A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
108
+ the model config. When used together with the `shift_invariant=True` flag, the model is also called
109
+ "affine-invariant". NB: overriding this value is not supported.
110
+ shift_invariant (`bool`, *optional*):
111
+ A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
112
+ the model config. When used together with the `scale_invariant=True` flag, the model is also called
113
+ "affine-invariant". NB: overriding this value is not supported.
114
+ default_denoising_steps (`int`, *optional*):
115
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
116
+ quality with the given model. This value must be set in the model config. When the pipeline is called
117
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
118
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
119
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
120
+ default_processing_resolution (`int`, *optional*):
121
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
122
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
123
+ default value is used. This is required to ensure reasonable results with various model flavors trained
124
+ with varying optimal processing resolution values.
125
+ """
126
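+ # A minimal usage sketch (not from the original code; the checkpoint path, input
+ # images and scheduler choice below are illustrative assumptions):
+ #
+ #   pipe = MarigoldInpaintPipeline.from_pretrained("path/to/marigold-inpaint-checkpoint")
+ #   pipe.rgb_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ #   pipe.depth_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ #   grid = pipe._rgbd_inpaint(rgb_image, depth_map, mask, prompt="a wooden chair",
+ #                             mode="joint_inpaint", num_inference_steps=50)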
+
127
+ rgb_latent_scale_factor = 0.18215
128
+ depth_latent_scale_factor = 0.18215
129
+
130
+ def __init__(
131
+ self,
132
+ unet: DoubleUNet2DConditionModel,
133
+ vae: AutoencoderKL,
134
+ scheduler: Union[DDIMScheduler, LCMScheduler],
135
+ text_encoder: CLIPTextModel,
136
+ tokenizer: CLIPTokenizer,
137
+ scale_invariant: Optional[bool] = True,
138
+ shift_invariant: Optional[bool] = True,
139
+ default_denoising_steps: Optional[int] = None,
140
+ default_processing_resolution: Optional[int] = None,
141
+ requires_safety_checker: bool = False,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.register_modules(
146
+ unet=unet,
147
+ vae=vae,
148
+ scheduler=scheduler,
149
+ text_encoder=text_encoder,
150
+ tokenizer=tokenizer,
151
+ )
152
+ self.register_to_config(
153
+ scale_invariant=scale_invariant,
154
+ shift_invariant=shift_invariant,
155
+ default_denoising_steps=default_denoising_steps,
156
+ default_processing_resolution=default_processing_resolution,
157
+ )
158
+
159
+ self.scale_invariant = scale_invariant
160
+ self.shift_invariant = shift_invariant
161
+ self.default_denoising_steps = default_denoising_steps
162
+ self.default_processing_resolution = default_processing_resolution
163
+ self.rgb_scheduler = None
164
+ self.depth_scheduler = None
165
+
166
+ self.empty_text_embed = None
167
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
168
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
169
+ self.mask_processor = VaeImageProcessor(
170
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
171
+ )
172
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
173
+ self.separate_list = [0,0]
174
+
175
+ @torch.no_grad()
176
+ def __call__(
177
+ self,
178
+ input_image: Union[Image.Image, torch.Tensor],
179
+ denoising_steps: Optional[int] = None,
180
+ ensemble_size: int = 5,
181
+ processing_res: Optional[int] = None,
182
+ match_input_res: bool = True,
183
+ resample_method: str = "bilinear",
184
+ batch_size: int = 0,
185
+ generator: Union[torch.Generator, None] = None,
186
+ color_map: str = "Spectral",
187
+ show_progress_bar: bool = True,
188
+ ensemble_kwargs: Dict = None,
189
+ ) -> MarigoldDepthOutput:
190
+ """
191
+ Function invoked when calling the pipeline.
192
+
193
+ Args:
194
+ input_image (`Image`):
195
+ Input RGB (or gray-scale) image.
196
+ denoising_steps (`int`, *optional*, defaults to `None`):
197
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
198
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
199
+ for Marigold-LCM models.
200
+ ensemble_size (`int`, *optional*, defaults to `5`):
201
+ Number of predictions to be ensembled.
202
+ processing_res (`int`, *optional*, defaults to `None`):
203
+ Effective processing resolution. When set to `0`, processes at the original image resolution. This
204
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
205
+ value `None` resolves to the optimal value from the model config.
206
+ match_input_res (`bool`, *optional*, defaults to `True`):
207
+ Resize depth prediction to match input resolution.
208
+ Only valid if `processing_res` > 0.
209
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
210
+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
211
+ batch_size (`int`, *optional*, defaults to `0`):
212
+ Inference batch size, no bigger than `num_ensemble`.
213
+ If set to 0, the script will automatically decide the proper batch size.
214
+ generator (`torch.Generator`, *optional*, defaults to `None`)
215
+ Random generator for initial noise generation.
216
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
217
+ Display a progress bar of diffusion denoising.
218
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
219
+ Colormap used to colorize the depth map.
220
+ scale_invariant (`str`, *optional*, defaults to `True`):
221
+ Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
222
+ shift_invariant (`str`, *optional*, defaults to `True`):
223
+ Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, near plane will be fixed at 0m.
224
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
225
+ Arguments for detailed ensembling settings.
226
+ Returns:
227
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
228
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
229
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
230
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
231
+ coming from ensembling. None if `ensemble_size = 1`
232
+ """
233
+ # Model-specific optimal default values leading to fast and reasonable results.
234
+ if denoising_steps is None:
235
+ denoising_steps = self.default_denoising_steps
236
+ if processing_res is None:
237
+ processing_res = self.default_processing_resolution
238
+
239
+ assert processing_res >= 0
240
+ assert ensemble_size >= 1
241
+
242
+ # Check if denoising step is reasonable
243
+ self._check_inference_step(denoising_steps)
244
+
245
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
246
+
247
+ # ----------------- Image Preprocess -----------------
248
+ # Convert to torch tensor
249
+ if isinstance(input_image, Image.Image):
250
+ input_image = input_image.convert("RGB")
251
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
252
+ rgb = pil_to_tensor(input_image)
253
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
254
+ elif isinstance(input_image, torch.Tensor):
255
+ rgb = input_image
256
+ else:
257
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
258
+ input_size = rgb.shape
259
+ assert (
260
+ 4 == rgb.dim() and 3 == input_size[-3]
261
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
262
+
263
+ # Resize image
264
+ if processing_res > 0:
265
+ rgb = resize_max_res(
266
+ rgb,
267
+ max_edge_resolution=processing_res,
268
+ resample_method=resample_method,
269
+ )
270
+
271
+ # Normalize rgb values
272
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
273
+ rgb_norm = rgb_norm.to(self.dtype)
274
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
275
+
276
+ # ----------------- Predicting depth -----------------
277
+ # Batch repeated input image
278
+ duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
279
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
280
+ if batch_size > 0:
281
+ _bs = batch_size
282
+ else:
283
+ _bs = find_batch_size(
284
+ ensemble_size=ensemble_size,
285
+ input_res=max(rgb_norm.shape[1:]),
286
+ dtype=self.dtype,
287
+ )
288
+
289
+ single_rgb_loader = DataLoader(
290
+ single_rgb_dataset, batch_size=_bs, shuffle=False
291
+ )
292
+
293
+ # Predict depth maps (batched)
294
+ depth_pred_ls = []
295
+ if show_progress_bar:
296
+ iterable = tqdm(
297
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
298
+ )
299
+ else:
300
+ iterable = single_rgb_loader
301
+ for batch in iterable:
302
+ (batched_img,) = batch
303
+ depth_pred_raw = self.single_infer(
304
+ rgb_in=batched_img,
305
+ num_inference_steps=denoising_steps,
306
+ show_pbar=show_progress_bar,
307
+ generator=generator,
308
+ )
309
+ depth_pred_ls.append(depth_pred_raw.detach())
310
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
311
+ torch.cuda.empty_cache() # clear vram cache for ensembling
312
+
313
+ # ----------------- Test-time ensembling -----------------
314
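+ # ensemble_depth aligns the affine-invariant predictions (scale/shift) before aggregating them; max_res limits the resolution used for the alignment step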
+ if ensemble_size > 1:
315
+ depth_pred, pred_uncert = ensemble_depth(
316
+ depth_preds,
317
+ scale_invariant=self.scale_invariant,
318
+ shift_invariant=self.shift_invariant,
319
+ max_res=50,
320
+ **(ensemble_kwargs or {}),
321
+ )
322
+ else:
323
+ depth_pred = depth_preds
324
+ pred_uncert = None
325
+
326
+ # Resize back to original resolution
327
+ if match_input_res:
328
+ depth_pred = resize(
329
+ depth_pred,
330
+ input_size[-2:],
331
+ interpolation=resample_method,
332
+ antialias=True,
333
+ )
334
+
335
+ # Convert to numpy
336
+ depth_pred = depth_pred.squeeze()
337
+ depth_pred = depth_pred.cpu().numpy()
338
+ if pred_uncert is not None:
339
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
340
+
341
+ # Clip output range
342
+ depth_pred = depth_pred.clip(0, 1)
343
+
344
+ # Colorize
345
+ if color_map is not None:
346
+ depth_colored = colorize_depth_maps(
347
+ depth_pred, 0, 1, cmap=color_map
348
+ ).squeeze() # [3, H, W], value in (0, 1)
349
+ depth_colored = (depth_colored * 255).astype(np.uint8)
350
+ depth_colored_hwc = chw2hwc(depth_colored)
351
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
352
+ else:
353
+ depth_colored_img = None
354
+
355
+ return MarigoldDepthOutput(
356
+ depth_np=depth_pred,
357
+ depth_colored=depth_colored_img,
358
+ uncertainty=pred_uncert,
359
+ )
360
+
361
+ def _replace_unet_conv_in(self):
362
+ # replace the first layer to accept 8 in_channels
363
+ _weight = self.unet.conv_in.weight.clone() # [320, 4, 3, 3]
364
+ _bias = self.unet.conv_in.bias.clone() # [320]
365
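+ # zero-initialize the weights of the 4 extra input channels so the widened conv_in initially behaves exactly like the original layer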
+ zero_weight = torch.zeros(_weight.shape).to(_weight.device)
366
+ _weight = torch.cat([_weight, zero_weight], dim=1)
367
+ # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
368
+ # half the activation magnitude
369
+ # _weight *= 0.5
370
+ # new conv_in channel
371
+ _n_convin_out_channel = self.unet.conv_in.out_channels
372
+ _new_conv_in = Conv2d(
373
+ 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
374
+ )
375
+ _new_conv_in.weight = Parameter(_weight)
376
+ _new_conv_in.bias = Parameter(_bias)
377
+ self.unet.conv_in = _new_conv_in
378
+ logging.info("Unet conv_in layer is replaced")
379
+ # replace config
380
+ self.unet.config["in_channels"] = 8
381
+ logging.info("Unet config is updated")
382
+ return
383
+
384
+ def _replace_unet_conv_out(self):
385
+ # replace the last layer to output 8 channels
386
+ _weight = self.unet.conv_out.weight.clone() # [4, 320, 3, 3]
387
+ _bias = self.unet.conv_out.bias.clone() # [4]
388
+ _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s)
389
+ _bias = _bias.repeat((2))
390
+ # half the activation magnitude
391
+
392
+ # new conv_out layer producing 8 output channels
393
+ _n_convout_in_channel = self.unet.conv_out.in_channels
394
+ _new_conv_out = Conv2d(
395
+ _n_convout_in_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
396
+ )
397
+ _new_conv_out.weight = Parameter(_weight)
398
+ _new_conv_out.bias = Parameter(_bias)
399
+ self.unet.conv_out = _new_conv_out
400
+ logging.info("Unet conv_out layer is replaced")
401
+ # replace config
402
+ self.unet.config["out_channels"] = 8
403
+ logging.info("Unet config is updated")
404
+ return
405
+
406
+ def _check_inference_step(self, n_step: int) -> None:
407
+ """
408
+ Check if denoising step is reasonable
409
+ Args:
410
+ n_step (`int`): denoising steps
411
+ """
412
+ assert n_step >= 1
413
+
414
+ if isinstance(self.scheduler, DDIMScheduler):
415
+ if n_step < 10:
416
+ logging.warning(
417
+ f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
418
+ )
419
+ elif isinstance(self.scheduler, LCMScheduler):
420
+ if not 1 <= n_step <= 4:
421
+ logging.warning(
422
+ f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
423
+ )
424
+ elif isinstance(self.scheduler, PNDMScheduler):
425
+ if n_step < 10:
426
+ logging.warning(
427
+ f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
428
+ )
429
+ else:
430
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
431
+
432
+ def encode_empty_text(self):
433
+ """
434
+ Encode text embedding for empty prompt
435
+ """
436
+ prompt = ""
437
+ text_inputs = self.tokenizer(
438
+ prompt,
439
+ padding="max_length",
440
+ max_length=self.tokenizer.model_max_length,
441
+ truncation=True,
442
+ return_tensors="pt",
443
+ )
444
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
445
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
446
+
447
+ def encode_text(self, prompt):
448
+ """
449
+ Encode text embedding for the given prompt
450
+ """
451
+ text_inputs = self.tokenizer(
452
+ prompt,
453
+ padding="max_length",
454
+ max_length=self.tokenizer.model_max_length,
455
+ truncation=True,
456
+ return_tensors="pt",
457
+ )
458
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
459
+ text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
460
+ return text_embed
461
+
462
+ def numpy_to_pil(self, images: np.ndarray) -> PIL.Image.Image:
463
+ """
464
+ Convert a numpy image or a batch of images to a PIL image.
465
+ """
466
+ if images.ndim == 3:
467
+ images = images[None, ...]
468
+ images = (images * 255).round().astype("uint8")
469
+ if images.shape[-1] == 1:
470
+ # special case for grayscale (single channel) images
471
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
472
+ else:
473
+ pil_images = [Image.fromarray(image) for image in images]
474
+
475
+ return pil_images
476
+
477
+ def full_depth_rgb_inpaint(self,
478
+ rgb_in,
479
+ depth_in,
480
+ image_mask,
481
+ text_embed,
482
+ timesteps,
483
+ generator,
484
+ guidance_scale,
485
+ ):
486
+ depth_latent = self.encode_depth(depth_in)
487
+ depth_mask = torch.zeros_like(image_mask)
488
+ depth_mask_latent = self.encode_depth(depth_in)
489
+
490
+ rgb_latent = torch.randn(
491
+ depth_latent.shape,
492
+ device=self.device,
493
+ dtype=self.unet.dtype,
494
+ generator=generator,
495
+ ) * self.rgb_scheduler.init_noise_sigma
496
+ rgb_mask = image_mask
497
+ rgb_mask_latent = self.encode_rgb(rgb_in * (image_mask.squeeze() < 0.5), generator=generator)
498
+
499
+ rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
500
+ depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])
501
+
502
+ for i, t in enumerate(timesteps):
503
+ cat_latent = torch.cat(
504
+ [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent, depth_latent, depth_mask, rgb_mask_latent,
505
+ depth_mask_latent], dim=1
506
+ ).float() # [B, 9*2, h, w]
507
+
508
+ latent_model_input = torch.cat([cat_latent] * 2)
509
+
510
+ # predict the noise residual
511
+ with torch.no_grad():
512
+ partial_noise_pred = self.unet(
513
+ latent_model_input,
514
+ rgb_timestep=t,
515
+ depth_timestep=t,
516
+ encoder_hidden_states=text_embed,
517
+ return_dict=False,
518
+ depth2rgb_scale=0.2
519
+ )[0]
520
+ noise_pred = self.unet(
521
+ latent_model_input,
522
+ rgb_timestep=t,
523
+ depth_timestep=t,
524
+ encoder_hidden_states=text_embed,
525
+ return_dict=False,
526
+ # separate_list=self.separate_list
527
+ )[0]
528
+ # perform guidance
529
+ rgb_pred_wo_depth_text = partial_noise_pred[0, :4, :, :]
530
+ rgb_pred_wo_text = noise_pred[0, :4, :, :]
531
+ rgb_pred = noise_pred[1, :4, :, :]
532
+ noise_pred = rgb_pred_wo_depth_text + 2 * (rgb_pred_wo_text - rgb_pred_wo_depth_text) + 3 * (rgb_pred - rgb_pred_wo_text)
533
+
534
+ # compute the previous noisy sample x_t -> x_t-1
535
+ rgb_latent = self.rgb_scheduler.step(noise_pred, t, rgb_latent).prev_sample
536
+ return rgb_latent, depth_latent
537
+
538
+ def full_rgb_depth_inpaint(self,
539
+ rgb_in,
540
+ depth_in,
541
+ image_mask,
542
+ text_embed,
543
+ timesteps,
544
+ generator,
545
+ guidance_scale
546
+ ):
547
+ rgb_latent = self.encode_rgb(rgb_in)
548
+ rgb_mask = torch.zeros_like(image_mask)
549
+ rgb_mask_latent = self.encode_rgb(rgb_in)
550
+
551
+ depth_latent = torch.randn(
552
+ rgb_latent.shape,
553
+ device=self.device,
554
+ dtype=self.unet.dtype,
555
+ generator=generator,
556
+ ) * self.depth_scheduler.init_noise_sigma
557
+ depth_mask = image_mask
558
+ depth_mask_latent = self.encode_depth(depth_in * (image_mask.squeeze() < 0.5))
559
+
560
+ rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
561
+ depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])
562
+
563
+ for i, t in enumerate(timesteps):
564
+ cat_latent = torch.cat(
565
+ [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent, depth_latent, depth_mask, rgb_mask_latent,
566
+ depth_mask_latent], dim=1
567
+ ).float() # [B, 9*2, h, w]
568
+
569
+ latent_model_input = torch.cat([cat_latent] * 2)
570
+
571
+ # predict the noise residual
572
+ with torch.no_grad():
573
+ partial_noise_pred = self.unet(
574
+ latent_model_input,
575
+ rgb_timestep=t,
576
+ depth_timestep=t,
577
+ encoder_hidden_states=text_embed,
578
+ return_dict=False,
579
+ rgb2depth_scale=0.2
580
+ )[0]
581
+ noise_pred = self.unet(
582
+ latent_model_input,
583
+ rgb_timestep=t,
584
+ depth_timestep=t,
585
+ encoder_hidden_states=text_embed,
586
+ return_dict=False,
587
+ # separate_list=self.separate_list
588
+ )[0]
589
+ # perform guidance on the depth branch
590
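+ # amplify the RGB-conditioning direction with a hardcoded weight of 4; the guidance_scale argument is not applied here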
+ depth_pre_wo_rgb = partial_noise_pred[1, 4:, :, :]
591
+
592
+ depth_pre = depth_pre_wo_rgb + 4 * (noise_pred[1, 4:, :, :] - depth_pre_wo_rgb)
593
+
594
+ depth_latent = self.depth_scheduler.step(depth_pre, t, depth_latent, generator=generator).prev_sample
595
+ return rgb_latent, depth_latent
596
+
597
+ def joint_inpaint(self,
598
+ rgb_in,
599
+ depth_in,
600
+ image_mask,
601
+ text_embed,
602
+ timesteps,
603
+ generator,
604
+ guidance_scale
605
+ ):
606
+ bs = rgb_in.shape[0]
607
+ h, w = int(rgb_in.shape[-2]/8), int(rgb_in.shape[-1]/8)
608
+ rgb_latent = torch.randn(
609
+ [bs, 4, h, w],
610
+ device=self.device,
611
+ dtype=self.unet.dtype,
612
+ generator=generator,
613
+ ) * self.rgb_scheduler.init_noise_sigma
614
+ rgb_mask = image_mask
615
+ rgb_mask_latent = self.encode_rgb(rgb_in * (rgb_mask.squeeze() < 0.5), generator=generator)
616
+
617
+ depth_latent = torch.randn(
618
+ [bs, 4, h, w],
619
+ device=self.device,
620
+ dtype=self.unet.dtype,
621
+ generator=generator,
622
+ ) * self.depth_scheduler.init_noise_sigma
623
+ depth_mask = image_mask
624
+ depth_mask_latent = self.encode_depth(depth_in * (image_mask.squeeze() < 0.5))
625
+
626
+ rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
627
+ depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])
628
+
629
+ for i, t in enumerate(timesteps):
630
+ cat_latent = torch.cat(
631
+ [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent, depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent], dim=1
632
+ ).float() # [B, 9*2, h, w]
633
+
634
+ latent_model_input = torch.cat([cat_latent] * 2)
635
+ # predict the noise residual
636
+ with torch.no_grad():
637
+ partial_noise_pred = self.unet(
638
+ latent_model_input,
639
+ rgb_timestep=t,
640
+ depth_timestep=t,
641
+ encoder_hidden_states=text_embed,
642
+ return_dict=False,
643
+ depth2rgb_scale=0,
644
+ rgb2depth_scale=0.2
645
+ )[0]
646
+ noise_pred = self.unet(
647
+ latent_model_input,
648
+ rgb_timestep=t,
649
+ depth_timestep=t,
650
+ encoder_hidden_states=text_embed,
651
+ return_dict=False,
652
+ )[0]
653
+
654
+ # perform guidance
655
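+ # each forward pass is chunked into (empty-prompt, text-conditioned) halves; the depth branch amplifies the fully-conditioned direction with a hardcoded weight of 3, while the RGB branch uses the conditioned prediction directly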
+ noise_pred_untext_undual, noise_pred_undual = partial_noise_pred.chunk(2)
656
+ noise_pred_untext, noise_pred_cond = noise_pred.chunk(2)
657
+
658
+ # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
659
+ depth_noise_pred = noise_pred_undual + 3 * (noise_pred_cond - noise_pred_undual)
660
+
661
+ rgb_latent = self.rgb_scheduler.step(noise_pred_cond[:, :4, :, :], t, rgb_latent, return_dict=False)[0]
662
+ depth_latent = self.depth_scheduler.step(depth_noise_pred[:, 4:, :, :], t, depth_latent, generator=generator, return_dict=False)[0]
663
+ return rgb_latent, depth_latent
664
+
665
+ @torch.no_grad()
666
+ def _rgbd_inpaint(self,
667
+ input_image: Union[torch.Tensor, PIL.Image.Image],
668
+ depth_image: Union[torch.Tensor, PIL.Image.Image],
669
+ mask: Union[torch.Tensor, PIL.Image.Image],
670
+ prompt: str = '',
671
+ guidance_scale: float = 4.5,
672
+ generator: Union[torch.Generator, None] = None,
673
+ num_inference_steps: int = 50,
674
+ resample_method: str = "bilinear",
675
+ processing_res: int = 512,
676
+ mode: str = 'full_depth_rgb_inpaint'
677
+ ) -> PIL.Image.Image:
678
+ self._check_inference_step(num_inference_steps)
679
+
680
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
681
+
682
+ # ----------------- encoder prompt -----------------
683
+ if isinstance(prompt, list):
684
+ bs = len(prompt)
685
+ batch_text_embed = []
686
+ for p in prompt:
687
+ batch_text_embed.append(self.encode_text(p))
688
+ batch_text_embed = torch.cat(batch_text_embed, dim=0)
689
+ elif isinstance(prompt, str):
690
+ bs = 1
691
+ batch_text_embed = self.encode_text(prompt).unsqueeze(0)
692
+ else:
693
+ raise NotImplementedError
694
+
695
+ if self.empty_text_embed is None:
696
+ self.encode_empty_text()
697
+ batch_empty_text_embed = self.empty_text_embed.repeat(
698
+ (batch_text_embed.shape[0], 1, 1)
699
+ ).to(self.device) # [B, 2, 1024]
700
+ text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)
701
+
702
+ # ----------------- Image Preprocess -----------------
703
+ # Convert to torch tensor
704
+ if isinstance(input_image, Image.Image):
705
+ rgb_in = self.image_processor.preprocess(input_image, height=processing_res,
706
+ width=processing_res).to(self.dtype).to(self.device)
707
+ elif isinstance(input_image, torch.Tensor):
708
+ rgb = input_image.unsqueeze(0)
709
+ input_size = rgb.shape
710
+ assert (
711
+ 4 == rgb.dim() and 3 == input_size[-3]
712
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
713
+ if processing_res > 0:
714
+ rgb = resize(rgb, [processing_res, processing_res], resample_method, antialias=True)
715
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
716
+ rgb_in = rgb_norm.to(self.dtype).to(self.device)
717
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
718
+
719
+ if isinstance(depth_image, Image.Image):
720
+ depth = pil_to_tensor(depth_image)
721
+ depth = depth.unsqueeze(0) # [1, rgb, H, W]
722
+ elif isinstance(depth_image, torch.Tensor):
723
+ if len(depth_image.shape) == 3:
724
+ depth = depth_image.unsqueeze(0)
725
+ else:
726
+ depth = depth_image
727
+ # pdb.set_trace()
728
+ depth = depth.repeat(1, 3, 1, 1)
729
+ input_size = depth.shape
730
+ assert (
731
+ 4 == depth.dim() and 3 == input_size[-3]
732
+ ), f"Wrong input shape {input_size}, expected [1, 1, H, W]"
733
+ if processing_res > 0:
734
+ depth = resize(depth, [processing_res, processing_res], resample_method, antialias=True)
735
+ depth_norm: torch.Tensor = (depth - depth.min()) / (
736
+ depth.max() - depth.min()) * 2.0 - 1.0 # [0, 255] -> [-1, 1]
737
+ depth_in = depth_norm.to(self.dtype).to(self.device)
738
+ assert depth_norm.min() >= -1.0 and depth_norm.max() <= 1.0
739
+
740
+ if (mask.max() - mask.min()) != 0:
741
+ mask = (mask - mask.min()) / (mask.max() - mask.min()) * 255
742
+ image_mask = self.mask_processor.preprocess(mask, height=processing_res, width=processing_res).to(self.device)
743
+
744
+ self.rgb_scheduler.set_timesteps(num_inference_steps, device=self.device)
745
+ self.depth_scheduler.set_timesteps(num_inference_steps, device=self.device)
746
+ timesteps = self.rgb_scheduler.timesteps
747
+
748
+ if mode == 'full_rgb_depth_inpaint':
749
+ rgb_latent, depth_latent = self.full_rgb_depth_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
750
+ generator, guidance_scale=guidance_scale)
751
+ if mode == 'partial_depth_rgb_inpaint':
752
+ rgb_latent, depth_latent = self.partial_depth_rgb_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
753
+ generator, guidance_scale=guidance_scale)
754
+ if mode == 'full_depth_rgb_inpaint':
755
+ rgb_latent, depth_latent = self.full_depth_rgb_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
756
+ generator, guidance_scale=guidance_scale)
757
+ if mode == 'joint_inpaint':
758
+ rgb_latent, depth_latent = self.joint_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
759
+ generator, guidance_scale=guidance_scale)
760
+
761
+ image = self.decode_image(rgb_latent)
762
+ image = self.numpy_to_pil(image)[0]
763
+
764
+ d_image = self.decode_depth(depth_latent)
765
+ d_image = d_image.cpu().permute(0, 2, 3, 1).numpy()
766
+ d_image = (d_image - d_image.min()) / (d_image.max() - d_image.min())
767
+ d_image = self.numpy_to_pil(d_image)[0]
768
+
769
+ depth = depth.squeeze().permute(1, 2, 0).cpu().numpy()
770
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
771
+ ori_depth = self.numpy_to_pil(depth)[0]
772
+
773
+ ori_image = input_image.squeeze().permute(1, 2, 0).cpu().numpy()
774
+ ori_image = self.numpy_to_pil(ori_image/255)[0]
775
+
776
+ image_mask = self.numpy_to_pil(image_mask.permute(0, 2, 3, 1).cpu().numpy())[0]
777
+ cat_image = make_image_grid([ori_image, ori_depth, image_mask, image, d_image], rows=1, cols=5)
778
+ return cat_image
779
+
780
+
781
+ def encode_rgb(self, rgb_in: torch.Tensor, generator=None) -> torch.Tensor:
782
+ """
783
+ Encode RGB image into latent.
784
+
785
+ Args:
786
+ rgb_in (`torch.Tensor`):
787
+ Input RGB image to be encoded.
788
+
789
+ Returns:
790
+ `torch.Tensor`: Image latent.
791
+ """
792
+ # encode
793
+ image_latents = self.vae.encode(rgb_in).latent_dist.sample(generator=generator)
794
+ image_latents = self.vae.config.scaling_factor * image_latents
795
+ return image_latents
796
+
797
+ def encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
798
+ """
799
+ Encode depth map into latent.
800
+
801
+ Args:
802
+ depth_in (`torch.Tensor`):
803
+ Input depth map to be encoded.
804
+
805
+ Returns:
806
+ `torch.Tensor`: Depth latent.
807
+ """
808
+ # encode
809
+ h = self.vae.encoder(depth_in)
810
+ moments = self.vae.quant_conv(h)
811
+ mean, logvar = torch.chunk(moments, 2, dim=1)
812
+ # scale latent
813
+ depth_latent = mean * self.depth_latent_scale_factor
814
+ return depth_latent
815
+
816
+ def decode_image(self, latents):
817
+ latents = 1 / self.vae.config.scaling_factor * latents
818
+ z = self.vae.post_quant_conv(latents)
819
+ image = self.vae.decoder(z)
820
+ image = (image / 2 + 0.5).clamp(0, 1)
821
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
822
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
823
+ return image
824
+
825
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
826
+ """
827
+ Decode depth latent into depth map.
828
+
829
+ Args:
830
+ depth_latent (`torch.Tensor`):
831
+ Depth latent to be decoded.
832
+
833
+ Returns:
834
+ `torch.Tensor`: Decoded depth map.
835
+ """
836
+ # scale latent
837
+ depth_latent = depth_latent / self.depth_latent_scale_factor
838
+ # decode
839
+ z = self.vae.post_quant_conv(depth_latent)
840
+ stacked = self.vae.decoder(z)
841
+ # mean of output channels
842
+ depth_mean = stacked.mean(dim=1, keepdim=True)
843
+ return depth_mean
844
+
845
+ def post_process_rgbd(self, prompts, rgb_image, depth_image):
846
+
847
+ rgbd_images = []
848
+ for idx, p in enumerate(prompts):
849
+ image1, image2 = rgb_image[idx], depth_image[idx]
850
+
851
+ width1, height1 = image1.size
852
+ width2, height2 = image2.size
853
+
854
+ font = ImageFont.load_default(size=20)
855
+ text = p
856
+ draw = ImageDraw.Draw(image1)
857
+ text_bbox = draw.textbbox((0, 0), text, font=font)
858
+ text_width = text_bbox[2] - text_bbox[0]
859
+ text_height = text_bbox[3] - text_bbox[1]
860
+
861
+ new_image = Image.new('RGB', (width1 + width2, max(height1, height2) + text_height), (255, 255, 255))
862
+
863
+ text_x = (new_image.width - text_width) // 2
864
+ text_y = 0
865
+ draw = ImageDraw.Draw(new_image)
866
+ draw.text((text_x, text_y), text, fill="black", font=font)
867
+
868
+ new_image.paste(image1, (0, text_height))
869
+ new_image.paste(image2, (width1, text_height))
870
+
871
+ rgbd_images.append(pil_to_tensor(new_image))
872
+
873
+ return rgbd_images
marigold/marigold_inpainting_pipeline.py ADDED
@@ -0,0 +1 @@
1
+ from diffusers import StableDiffusionInpaintPipeline
marigold/marigold_pipeline.py ADDED
@@ -0,0 +1,1194 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+ from typing import Any, Callable, Dict, List, Optional, Union
22
+ import logging
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ import pdb
25
+ from diffusers.utils import load_image, make_image_grid
26
+ from typing import Dict, Optional, Union
27
+ import torchvision.transforms as transforms
28
+ import PIL.Image
29
+ import numpy as np
30
+ import torch
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ DiffusionPipeline,
35
+ LCMScheduler,
36
+ UNet2DConditionModel,
37
+ )
38
+ from .duplicate_unet import DoubleUNet2DConditionModel
39
+ import os
40
+ from torch.nn import Conv2d
41
+ from PIL import Image, ImageDraw, ImageFont
42
+ from torch.nn.parameter import Parameter
43
+ from diffusers.utils import BaseOutput
44
+ from PIL import Image
45
+ from torch.utils.data import DataLoader, TensorDataset
46
+ from torchvision.transforms import InterpolationMode
47
+ from torchvision.transforms.functional import pil_to_tensor, resize
48
+ from tqdm.auto import tqdm
49
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
50
+
51
+ from .util.batchsize import find_batch_size
52
+ from .util.ensemble import ensemble_depth
53
+ from .util.image_util import (
54
+ chw2hwc,
55
+ colorize_depth_maps,
56
+ get_tv_resample_method,
57
+ resize_max_res,
58
+ )
59
+
60
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
61
+ """
62
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
63
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
64
+ """
65
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
66
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
67
+ # rescale the results from guidance (fixes overexposure)
68
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
69
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
70
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
71
+ return noise_cfg
72
+
73
+ class MarigoldDepthOutput(BaseOutput):
74
+ """
75
+ Output class for Marigold monocular depth prediction pipeline.
76
+
77
+ Args:
78
+ depth_np (`np.ndarray`):
79
+ Predicted depth map, with depth values in the range of [0, 1].
80
+ depth_colored (`PIL.Image.Image`):
81
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
82
+ uncertainty (`None` or `np.ndarray`):
83
+ Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
84
+ """
85
+
86
+ depth_np: np.ndarray
87
+ depth_colored: Union[None, Image.Image]
88
+ uncertainty: Union[None, np.ndarray]
89
+
90
+ class MarigoldPipeline(DiffusionPipeline):
91
+ """
92
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
93
+
94
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
95
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
96
+
97
+ Args:
98
+ unet (`UNet2DConditionModel`):
99
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
100
+ vae (`AutoencoderKL`):
101
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
102
+ to and from latent representations.
103
+ scheduler (`DDIMScheduler`):
104
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
105
+ text_encoder (`CLIPTextModel`):
106
+ Text-encoder, for empty text embedding.
107
+ tokenizer (`CLIPTokenizer`):
108
+ CLIP tokenizer.
109
+ scale_invariant (`bool`, *optional*):
110
+ A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
111
+ the model config. When used together with the `shift_invariant=True` flag, the model is also called
112
+ "affine-invariant". NB: overriding this value is not supported.
113
+ shift_invariant (`bool`, *optional*):
114
+ A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
115
+ the model config. When used together with the `scale_invariant=True` flag, the model is also called
116
+ "affine-invariant". NB: overriding this value is not supported.
117
+ default_denoising_steps (`int`, *optional*):
118
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
119
+ quality with the given model. This value must be set in the model config. When the pipeline is called
120
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
121
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
122
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
123
+ default_processing_resolution (`int`, *optional*):
124
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
125
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
126
+ default value is used. This is required to ensure reasonable results with various model flavors trained
127
+ with varying optimal processing resolution values.
128
+ """
129
+
130
+ rgb_latent_scale_factor = 0.18215
131
+ depth_latent_scale_factor = 0.18215
132
+
133
+ def __init__(
134
+ self,
135
+ unet: DoubleUNet2DConditionModel,
136
+ vae: AutoencoderKL,
137
+ scheduler: Union[DDIMScheduler, LCMScheduler],
138
+ text_encoder: CLIPTextModel,
139
+ tokenizer: CLIPTokenizer,
140
+ scale_invariant: Optional[bool] = True,
141
+ shift_invariant: Optional[bool] = True,
142
+ default_denoising_steps: Optional[int] = None,
143
+ default_processing_resolution: Optional[int] = None,
144
+ requires_safety_checker: bool = False,
145
+ ):
146
+ super().__init__()
147
+
148
+ self.register_modules(
149
+ unet=unet,
150
+ vae=vae,
151
+ scheduler=scheduler,
152
+ text_encoder=text_encoder,
153
+ tokenizer=tokenizer,
154
+ )
155
+ self.register_to_config(
156
+ scale_invariant=scale_invariant,
157
+ shift_invariant=shift_invariant,
158
+ default_denoising_steps=default_denoising_steps,
159
+ default_processing_resolution=default_processing_resolution,
160
+ )
161
+
162
+ self.scale_invariant = scale_invariant
163
+ self.shift_invariant = shift_invariant
164
+ self.default_denoising_steps = default_denoising_steps
165
+ self.default_processing_resolution = default_processing_resolution
166
+
167
+ self.empty_text_embed = None
168
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
169
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
170
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
171
+ self.separate_list = [0,0]
172
+
173
+ @torch.no_grad()
174
+ def __call__(
175
+ self,
176
+ input_image: Union[Image.Image, torch.Tensor],
177
+ denoising_steps: Optional[int] = None,
178
+ ensemble_size: int = 5,
179
+ processing_res: Optional[int] = None,
180
+ match_input_res: bool = True,
181
+ resample_method: str = "bilinear",
182
+ batch_size: int = 0,
183
+ generator: Union[torch.Generator, None] = None,
184
+ color_map: str = "Spectral",
185
+ show_progress_bar: bool = True,
186
+ ensemble_kwargs: Dict = None,
187
+ ) -> MarigoldDepthOutput:
188
+ """
189
+ Function invoked when calling the pipeline.
190
+
191
+ Args:
192
+ input_image (`Image`):
193
+ Input RGB (or gray-scale) image.
194
+ denoising_steps (`int`, *optional*, defaults to `None`):
195
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
196
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
197
+ for Marigold-LCM models.
198
+ ensemble_size (`int`, *optional*, defaults to `10`):
199
+ Number of predictions to be ensembled.
200
+ processing_res (`int`, *optional*, defaults to `None`):
201
+ Effective processing resolution. When set to `0`, processes at the original image resolution. This
202
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
203
+ value `None` resolves to the optimal value from the model config.
204
+ match_input_res (`bool`, *optional*, defaults to `True`):
205
+ Resize depth prediction to match input resolution.
206
+ Only valid if `processing_res` > 0.
207
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
208
+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
209
+ batch_size (`int`, *optional*, defaults to `0`):
210
+ Inference batch size, no bigger than `num_ensemble`.
211
+ If set to 0, the script will automatically decide the proper batch size.
212
+ generator (`torch.Generator`, *optional*, defaults to `None`)
213
+ Random generator for initial noise generation.
214
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
215
+ Display a progress bar of diffusion denoising.
216
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
217
+ Colormap used to colorize the depth map.
218
+ scale_invariant (`str`, *optional*, defaults to `True`):
219
+ Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
220
+ shift_invariant (`str`, *optional*, defaults to `True`):
221
+ Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, near plane will be fixed at 0m.
222
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
223
+ Arguments for detailed ensembling settings.
224
+ Returns:
225
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
226
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
227
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
228
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
229
+ coming from ensembling. None if `ensemble_size = 1`
230
+ """
231
+ # Model-specific optimal default values leading to fast and reasonable results.
232
+ if denoising_steps is None:
233
+ denoising_steps = self.default_denoising_steps
234
+ if processing_res is None:
235
+ processing_res = self.default_processing_resolution
236
+
237
+ assert processing_res >= 0
238
+ assert ensemble_size >= 1
239
+
240
+ # Check if denoising step is reasonable
241
+ self._check_inference_step(denoising_steps)
242
+
243
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
244
+
245
+ # ----------------- Image Preprocess -----------------
246
+ # Convert to torch tensor
247
+ if isinstance(input_image, Image.Image):
248
+ input_image = input_image.convert("RGB")
249
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
250
+ rgb = pil_to_tensor(input_image)
251
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
252
+ elif isinstance(input_image, torch.Tensor):
253
+ rgb = input_image
254
+ else:
255
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
256
+ input_size = rgb.shape
257
+ assert (
258
+ 4 == rgb.dim() and 3 == input_size[-3]
259
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
260
+
261
+ # Resize image
262
+ if processing_res > 0:
263
+ rgb = resize_max_res(
264
+ rgb,
265
+ max_edge_resolution=processing_res,
266
+ resample_method=resample_method,
267
+ )
268
+
269
+ # Normalize rgb values
270
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
271
+ rgb_norm = rgb_norm.to(self.dtype)
272
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
273
+
274
+ # ----------------- Predicting depth -----------------
275
+ # Batch repeated input image
276
+ duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
277
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
278
+ if batch_size > 0:
279
+ _bs = batch_size
280
+ else:
281
+ _bs = find_batch_size(
282
+ ensemble_size=ensemble_size,
283
+ input_res=max(rgb_norm.shape[1:]),
284
+ dtype=self.dtype,
285
+ )
286
+
287
+ single_rgb_loader = DataLoader(
288
+ single_rgb_dataset, batch_size=_bs, shuffle=False
289
+ )
290
+
291
+ # Predict depth maps (batched)
292
+ depth_pred_ls = []
293
+ if show_progress_bar:
294
+ iterable = tqdm(
295
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
296
+ )
297
+ else:
298
+ iterable = single_rgb_loader
299
+ for batch in iterable:
300
+ (batched_img,) = batch
301
+ depth_pred_raw = self.single_infer(
302
+ rgb_in=batched_img,
303
+ num_inference_steps=denoising_steps,
304
+ show_pbar=show_progress_bar,
305
+ generator=generator,
306
+ )
307
+ depth_pred_ls.append(depth_pred_raw.detach())
308
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
309
+ torch.cuda.empty_cache() # clear vram cache for ensembling
310
+
311
+ # ----------------- Test-time ensembling -----------------
312
+ if ensemble_size > 1:
313
+ depth_pred, pred_uncert = ensemble_depth(
314
+ depth_preds,
315
+ scale_invariant=self.scale_invariant,
316
+ shift_invariant=self.shift_invariant,
317
+ max_res=50,
318
+ **(ensemble_kwargs or {}),
319
+ )
320
+ else:
321
+ depth_pred = depth_preds
322
+ pred_uncert = None
323
+
324
+ # Resize back to original resolution
325
+ if match_input_res:
326
+ depth_pred = resize(
327
+ depth_pred,
328
+ input_size[-2:],
329
+ interpolation=resample_method,
330
+ antialias=True,
331
+ )
332
+
333
+ # Convert to numpy
334
+ depth_pred = depth_pred.squeeze()
335
+ depth_pred = depth_pred.cpu().numpy()
336
+ if pred_uncert is not None:
337
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
338
+
339
+ # Clip output range
340
+ depth_pred = depth_pred.clip(0, 1)
341
+
342
+ # Colorize
343
+ if color_map is not None:
344
+ depth_colored = colorize_depth_maps(
345
+ depth_pred, 0, 1, cmap=color_map
346
+ ).squeeze() # [3, H, W], value in (0, 1)
347
+ depth_colored = (depth_colored * 255).astype(np.uint8)
348
+ depth_colored_hwc = chw2hwc(depth_colored)
349
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
350
+ else:
351
+ depth_colored_img = None
352
+
353
+ return MarigoldDepthOutput(
354
+ depth_np=depth_pred,
355
+ depth_colored=depth_colored_img,
356
+ uncertainty=pred_uncert,
357
+ )
358
+
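+ # --- Usage sketch (illustrative, not part of the pipeline) ---
+ # A minimal, hedged example of invoking the depth-estimation entry point
+ # above. The local checkpoint path "./marigold_rgbd_checkpoint" and the
+ # input file name are hypothetical placeholders; substitute your own.
+ #
+ #     from PIL import Image
+ #     from marigold import MarigoldPipeline
+ #
+ #     pipe = MarigoldPipeline.from_pretrained("./marigold_rgbd_checkpoint").to("cuda")
+ #     out = pipe(
+ #         Image.open("example.jpg"),
+ #         denoising_steps=10,      # None falls back to the model default
+ #         ensemble_size=5,
+ #         show_progress_bar=True,
+ #     )
+ #     out.depth_np       # np.ndarray with values in [0, 1]
+ #     out.depth_colored  # PIL.Image.Image, or None if color_map=None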
359
+ def _replace_unet_conv_in(self):
360
+ # replace the first layer to accept 8 in_channels
361
+ _weight = self.unet.conv_in.weight.clone() # [320, 4, 3, 3]
362
+ _bias = self.unet.conv_in.bias.clone() # [320]
363
+ zero_weight = torch.zeros(_weight.shape).to(_weight.device)
364
+ _weight = torch.cat([_weight, zero_weight], dim=1)
365
+ # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
366
+ # half the activation magnitude
367
+ # _weight *= 0.5
368
+ # new conv_in channel
369
+ _n_convin_out_channel = self.unet.conv_in.out_channels
370
+ _new_conv_in = Conv2d(
371
+ 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
372
+ )
373
+ _new_conv_in.weight = Parameter(_weight)
374
+ _new_conv_in.bias = Parameter(_bias)
375
+ self.unet.conv_in = _new_conv_in
376
+ logging.info("Unet conv_in layer is replaced")
377
+ # replace config
378
+ self.unet.config["in_channels"] = 8
379
+ logging.info("Unet config is updated")
380
+ return
381
+
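+ # --- Weight-inflation sketch (illustrative) ---
+ # A self-contained, hedged sketch of the same surgery performed by
+ # _replace_unet_conv_in above, written against a plain diffusers
+ # UNet2DConditionModel whose conv_in has 4 input channels. The extra four
+ # input channels receive zero-initialized weights, so right after surgery
+ # the layer behaves exactly like the original on the first four channels.
+ # The function name is illustrative only.
+ #
+ #     import torch
+ #     from torch.nn import Conv2d, Parameter
+ #
+ #     def inflate_conv_in(unet, new_in_channels=8):
+ #         old = unet.conv_in
+ #         w, b = old.weight.clone(), old.bias.clone()  # [C_out, 4, 3, 3], [C_out]
+ #         zeros = torch.zeros_like(w)                  # zero init for the new channels
+ #         new_conv = Conv2d(new_in_channels, old.out_channels, kernel_size=3, padding=1)
+ #         new_conv.weight = Parameter(torch.cat([w, zeros], dim=1))
+ #         new_conv.bias = Parameter(b)
+ #         unet.conv_in = new_conv
+ #         unet.config["in_channels"] = new_in_channels
+ #         return unet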
382
+ def _replace_unet_conv_out(self):
383
+ # replace the output layer to produce 8 out_channels
384
+ _weight = self.unet.conv_out.weight.clone() # [4, 320, 3, 3]
385
+ _bias = self.unet.conv_out.bias.clone() # [4]
386
+ _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s)
387
+ _bias = _bias.repeat((2))
388
+ # half the activation magnitude
389
+ # new conv_in channel
390
+ _n_convin_out_channel = self.unet.conv_out.out_channels
391
+ _new_conv_out = Conv2d(
392
+ _n_convin_out_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
393
+ )
394
+ _new_conv_out.weight = Parameter(_weight)
395
+ _new_conv_out.bias = Parameter(_bias)
396
+ self.unet.conv_out = _new_conv_out
397
+ logging.info("Unet conv_out layer is replaced")
398
+ # replace config
399
+ self.unet.config["out_channels"] = 8
400
+ logging.info("Unet config is updated")
401
+ return
402
+
403
+ def _check_inference_step(self, n_step: int) -> None:
404
+ """
405
+ Check if denoising step is reasonable
406
+ Args:
407
+ n_step (`int`): denoising steps
408
+ """
409
+ assert n_step >= 1
410
+
411
+ # if isinstance(self.scheduler, DDIMScheduler):
412
+ # if n_step < 10:
413
+ # logging.warning(
414
+ # f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
415
+ # )
416
+ # elif isinstance(self.scheduler, LCMScheduler):
417
+ # if not 1 <= n_step <= 4:
418
+ # logging.warning(
419
+ # f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
420
+ # )
421
+ # else:
422
+ # raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
423
+
424
+ def encode_empty_text(self):
425
+ """
426
+ Encode text embedding for empty prompt
427
+ """
428
+ prompt = ""
429
+ text_inputs = self.tokenizer(
430
+ prompt,
431
+ padding="max_length",
432
+ max_length=self.tokenizer.model_max_length,
433
+ truncation=True,
434
+ return_tensors="pt",
435
+ )
436
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
437
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
438
+
439
+ def encode_text(self, prompt):
440
+ """
441
+ Encode text embedding for the given prompt
442
+ """
443
+ text_inputs = self.tokenizer(
444
+ prompt,
445
+ padding="max_length",
446
+ max_length=self.tokenizer.model_max_length,
447
+ truncation=True,
448
+ return_tensors="pt",
449
+ )
450
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
451
+ text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
452
+ return text_embed
453
+
454
+ def numpy_to_pil(self, images: np.ndarray) -> PIL.Image.Image:
455
+ """
456
+ Convert a numpy image or a batch of images to a list of PIL images.
457
+ """
458
+ if images.ndim == 3:
459
+ images = images[None, ...]
460
+ images = (images * 255).round().astype("uint8")
461
+ if images.shape[-1] == 1:
462
+ # special case for grayscale (single channel) images
463
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
464
+ else:
465
+ pil_images = [Image.fromarray(image) for image in images]
466
+
467
+ return pil_images
468
+
469
+ @torch.no_grad()
470
+ def generate_rgbd(
471
+ self,
472
+ prompt: Union[str, list],
473
+ num_inference_steps: int,
474
+ generator: Union[torch.Generator, None],
475
+ show_pbar: bool = None,
476
+ color_map: str = "Spectral",
477
+ height: int = 60,
478
+ width: int = 80
479
+ ):
480
+ """
481
+ Generate an aligned RGB image and depth map from a text prompt.
482
+
483
+ Args:
484
+ prompt (`str` or `list`):
485
+ Text prompt(s) to condition the generation on.
486
+ num_inference_steps (`int`):
487
+ Number of diffusion denoising steps (DDIM) during inference.
488
+ show_pbar (`bool`):
489
+ Display a progress bar of diffusion denoising.
490
+ generator (`torch.Generator`)
491
+ Random generator for initial noise generation.
492
+ Returns:
493
+ `List[torch.Tensor]`: Side-by-side RGB and colorized-depth visualizations, one per prompt.
494
+ """
495
+ device = self.device
496
+ ori_type = self.dtype
497
+
498
+ # Set timesteps
499
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
500
+ timesteps = self.scheduler.timesteps # [T]
501
+
502
+ if isinstance(prompt, list):
503
+ bs = len(prompt)
504
+ batch_text_embed = []
505
+ for p in prompt:
506
+ batch_text_embed.append(self.encode_text(p))
507
+ batch_text_embed = torch.cat(batch_text_embed, dim=0)
508
+ elif isinstance(prompt, str):
509
+ bs = 1
510
+ batch_text_embed = self.encode_text(prompt).unsqueeze(0)
511
+ else:
512
+ raise NotImplementedError
513
+
514
+ if self.empty_text_embed is None:
515
+ self.encode_empty_text()
516
+ batch_empty_text_embed = self.empty_text_embed.repeat(
517
+ (batch_text_embed.shape[0], 1, 1)
518
+ ).to(device) # [B, 2, 1024]
519
+
520
+ text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)
521
+
522
+ # Initial depth map (noise)
523
+ cat_latent = torch.randn(
524
+ [bs, self.unet.config["in_channels"], height, width],
525
+ device=device,
526
+ dtype=torch.bfloat16,
527
+ generator=generator,
528
+ ) * self.scheduler.init_noise_sigma # [B, 8, h, w]
529
+
530
+ # Denoising loop
531
+ if show_pbar:
532
+ iterable = tqdm(
533
+ enumerate(timesteps),
534
+ total=len(timesteps),
535
+ leave=False,
536
+ desc=" " * 4 + "Diffusion denoising",
537
+ )
538
+ else:
539
+ iterable = enumerate(timesteps)
540
+
541
+ self.to(torch.bfloat16)
542
+ for i, t in iterable:
543
+ latent_model_input = torch.cat([cat_latent] * 2)
544
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
545
+
546
+ # predict the noise residual
547
+ with torch.no_grad():
548
+ noise_pred = self.unet(
549
+ latent_model_input,
550
+ t,
551
+ t,
552
+ encoder_hidden_states=text_embed.to(torch.bfloat16),
553
+ return_dict=False,
554
+ # separate_list=self.separate_list
555
+ )[0]
556
+ # perform guidance
557
+ guidance_scale = 7.5
558
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
559
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
560
+
561
+ # compute the previous noisy sample x_t -> x_t-1
562
+ cat_latent = self.scheduler.step(noise_pred, t, cat_latent).prev_sample
563
+
564
+ # self.unet.to(default_dtype)
565
+ # cat_latent.to(default_dtype)
566
+
567
+ image = self.decode_image(cat_latent[:, 0:4, :, :])
568
+
569
+ image = self.numpy_to_pil(image)
570
+ # depth_pred = depth
571
+ depth = self.decode_depth(cat_latent[:, 4:, :, :])
572
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
573
+ # depth = torch.clip(depth, -1.0, 1.0)
574
+ # depth = (depth + 1.0) / 2.0
575
+ depth_pred = depth.squeeze()
576
+ depth_pred = depth_pred.float().cpu().numpy()
577
+ depth_pred = depth_pred.clip(0, 1)
578
+
579
+ # Colorize
580
+ if color_map is not None:
581
+ depth_colored_img = []
582
+ depth_colored = colorize_depth_maps(
583
+ depth_pred, 0, 1, cmap=color_map
584
+ ).squeeze() # [3, H, W], value in (0, 1)
585
+ depth_colored_img = self.numpy_to_pil(np.transpose(depth_colored, (0, 2, 3, 1)))
586
+ else:
587
+ depth_colored_img = None
588
+
589
+ rgbd_images = self.post_process_rgbd(prompt, image, depth_colored_img)
590
+ self.to(ori_type)
591
+
592
+ return rgbd_images
593
+
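+ # --- generate_rgbd usage sketch (illustrative) ---
+ # A hedged example of text-conditioned joint RGB-D generation with the
+ # method above. The checkpoint path and prompt are hypothetical; height and
+ # width are latent-space sizes here (60x80 latents correspond to 480x640
+ # pixels with the usual 8x VAE downsampling).
+ #
+ #     import torch
+ #     from marigold import MarigoldPipeline
+ #
+ #     pipe = MarigoldPipeline.from_pretrained("./marigold_rgbd_checkpoint").to("cuda")
+ #     rgbd_tiles = pipe.generate_rgbd(
+ #         prompt=["a sunlit living room with a wooden table"],
+ #         num_inference_steps=50,
+ #         generator=torch.Generator("cuda").manual_seed(0),
+ #         show_pbar=True,
+ #         height=60,
+ #         width=80,
+ #     )
+ #     # rgbd_tiles: list of CHW uint8 tensors, each showing the generated RGB
+ #     # image and its colorized depth map side by side beneath the prompt text.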
594
+ @torch.no_grad()
595
+ def image2depth(self,
596
+ input_image: Union[Image.Image, torch.Tensor],
597
+ denoising_steps: Optional[int] = None,
598
+ ensemble_size: int = 5,
599
+ processing_res: Optional[int] = None,
600
+ match_input_res: bool = True,
601
+ resample_method: str = "bilinear",
602
+ batch_size: int = 0,
603
+ generator: Union[torch.Generator, None] = None,
604
+ color_map: str = "Spectral",
605
+ show_progress_bar: bool = True,
606
+ ensemble_kwargs: Dict = None,
607
+ cfg_scale: float = 1.0
608
+ ):
609
+ # Model-specific optimal default values leading to fast and reasonable results.
610
+ if denoising_steps is None:
611
+ denoising_steps = self.default_denoising_steps
612
+ if processing_res is None:
613
+ processing_res = self.default_processing_resolution
614
+
615
+ ori_type = self.dtype
616
+ self.to(torch.bfloat16)
617
+
618
+ assert processing_res >= 0
619
+ assert ensemble_size >= 1
620
+
621
+ # Check if denoising step is reasonable
622
+ self._check_inference_step(denoising_steps)
623
+
624
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
625
+
626
+ # ----------------- Image Preprocess -----------------
627
+ # Convert to torch tensor
628
+ if isinstance(input_image, Image.Image):
629
+ input_image = input_image.convert("RGB")
630
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
631
+ rgb = pil_to_tensor(input_image)
632
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
633
+ elif isinstance(input_image, torch.Tensor):
634
+ rgb = input_image
635
+ else:
636
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
637
+ input_size = rgb.shape
638
+ assert (
639
+ 4 == rgb.dim() and 3 == input_size[-3]
640
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
641
+
642
+ # Resize image
643
+ if processing_res > 0:
644
+ rgb = resize_max_res(
645
+ rgb,
646
+ max_edge_resolution=processing_res,
647
+ resample_method=resample_method,
648
+ )
649
+
650
+ # Normalize rgb values
651
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
652
+ rgb_norm = rgb_norm.to(self.dtype)
653
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
654
+
655
+ # ----------------- Predicting depth -----------------
656
+ # Batch repeated input image
657
+ duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
658
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
659
+ if batch_size > 0:
660
+ _bs = batch_size
661
+ else:
662
+ _bs = find_batch_size(
663
+ ensemble_size=ensemble_size,
664
+ input_res=max(rgb_norm.shape[1:]),
665
+ dtype=self.dtype,
666
+ )
667
+
668
+ single_rgb_loader = DataLoader(
669
+ single_rgb_dataset, batch_size=_bs, shuffle=False
670
+ )
671
+
672
+ # Predict depth maps (batched)
673
+ depth_pred_ls = []
674
+ if show_progress_bar:
675
+ iterable = tqdm(
676
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
677
+ )
678
+ else:
679
+ iterable = single_rgb_loader
680
+ for batch in iterable:
681
+ (batched_img,) = batch
682
+ depth_pred_raw = self.single_image2depth(
683
+ rgb_in=batched_img,
684
+ num_inference_steps=denoising_steps,
685
+ show_pbar=show_progress_bar,
686
+ generator=generator,
687
+ cfg_scale=cfg_scale
688
+ )
689
+ depth_pred_ls.append(depth_pred_raw.detach())
690
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
691
+ torch.cuda.empty_cache() # clear vram cache for ensembling
692
+ depth_preds = depth_preds.to(torch.float32)
693
+ # ----------------- Test-time ensembling -----------------
694
+ if ensemble_size > 1:
695
+ depth_pred, pred_uncert = ensemble_depth(
696
+ depth_preds,
697
+ scale_invariant=self.scale_invariant,
698
+ shift_invariant=self.shift_invariant,
699
+ max_res=50,
700
+ **(ensemble_kwargs or {}),
701
+ )
702
+ else:
703
+ depth_pred = depth_preds
704
+ pred_uncert = None
705
+
706
+ # Resize back to original resolution
707
+ if match_input_res:
708
+ depth_pred = resize(
709
+ depth_pred,
710
+ input_size[-2:],
711
+ interpolation=resample_method,
712
+ antialias=True,
713
+ )
714
+
715
+ # Convert to numpy
716
+ depth_pred = depth_pred.squeeze()
717
+ depth_pred = depth_pred.cpu().numpy()
718
+ if pred_uncert is not None:
719
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
720
+
721
+ # Clip output range
722
+ depth_pred = depth_pred.clip(0, 1)
723
+
724
+ # Colorize
725
+ if color_map is not None:
726
+ depth_colored = colorize_depth_maps(
727
+ depth_pred, 0, 1, cmap=color_map
728
+ ).squeeze() # [3, H, W], value in (0, 1)
729
+ depth_colored = (depth_colored * 255).astype(np.uint8)
730
+ depth_colored_hwc = chw2hwc(depth_colored)
731
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
732
+ else:
733
+ depth_colored_img = None
734
+
735
+ self.to(ori_type)
736
+
737
+ return MarigoldDepthOutput(
738
+ depth_np=depth_pred,
739
+ depth_colored=depth_colored_img,
740
+ uncertainty=pred_uncert,
741
+ )
742
+
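+ # --- image2depth usage sketch (illustrative) ---
+ # A hedged example of the bfloat16 depth-estimation path above, including
+ # the rgb-to-depth classifier-free-guidance weight. The checkpoint path and
+ # file name are hypothetical placeholders.
+ #
+ #     from PIL import Image
+ #     from marigold import MarigoldPipeline
+ #
+ #     pipe = MarigoldPipeline.from_pretrained("./marigold_rgbd_checkpoint").to("cuda")
+ #     out = pipe.image2depth(
+ #         Image.open("example.jpg"),
+ #         ensemble_size=5,
+ #         cfg_scale=1.0,  # values > 1 enable the guidance branch in single_image2depth
+ #     )
+ #     out.depth_np  # np.ndarray with values in [0, 1]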
743
+ @torch.no_grad()
744
+ def single_image2depth(
745
+ self,
746
+ rgb_in: torch.Tensor,
747
+ num_inference_steps: int,
748
+ generator: Union[torch.Generator, None],
749
+ show_pbar: bool,
750
+ cfg_scale: float = 1.0
751
+ ) -> torch.Tensor:
752
+ """
753
+ Perform an individual depth prediction without ensembling.
754
+
755
+ Args:
756
+ rgb_in (`torch.Tensor`):
757
+ Input RGB image.
758
+ num_inference_steps (`int`):
759
+ Number of diffusion denoising steps (DDIM) during inference.
760
+ show_pbar (`bool`):
761
+ Display a progress bar of diffusion denoising.
762
+ generator (`torch.Generator`)
763
+ Random generator for initial noise generation.
764
+ Returns:
765
+ `torch.Tensor`: Predicted depth map.
766
+ """
767
+ device = self.device
768
+ rgb_in = rgb_in.to(device)
769
+
770
+ # Set timesteps
771
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
772
+ timesteps = self.scheduler.timesteps # [T]
773
+ # Encode image
774
+ rgb_latent = self.encode_rgb(rgb_in)
775
+
776
+ # Initial depth map (noise)
777
+ depth_latent = torch.randn(
778
+ rgb_latent.shape,
779
+ device=device,
780
+ dtype=self.dtype,
781
+ generator=generator,
782
+ ) * self.scheduler.init_noise_sigma # [B, 4, h, w]
783
+
784
+ # Batched empty text embedding
785
+ if self.empty_text_embed is None:
786
+ self.encode_empty_text()
787
+ batch_empty_text_embed = self.empty_text_embed.repeat(
788
+ (rgb_latent.shape[0], 1, 1)
789
+ ).to(device).to(self.dtype) # [B, 2, 1024]
790
+
791
+ # Denoising loop
792
+ if show_pbar:
793
+ iterable = tqdm(
794
+ enumerate(timesteps),
795
+ total=len(timesteps),
796
+ leave=False,
797
+ desc=" " * 4 + "Diffusion denoising",
798
+ )
799
+ else:
800
+ iterable = enumerate(timesteps)
801
+
802
+ for i, t in iterable:
803
+ unet_input = torch.cat(
804
+ [rgb_latent, depth_latent], dim=1
805
+ ) # this order is important
806
+ # predict the noise residual
807
+ noise_pred = self.unet(
808
+ unet_input, rgb_timestep=0, depth_timestep=t, encoder_hidden_states=batch_empty_text_embed
809
+ ).sample # [B, 4, h, w]
810
+
811
+ if cfg_scale > 1:
812
+ uncond_noise_pred = self.unet(
813
+ unet_input, rgb_timestep=0, depth_timestep=t, encoder_hidden_states=batch_empty_text_embed, rgb2depth_scale=0.3
814
+ ).sample # [B, 4, h, w]
815
+
816
+ uncond_pred = uncond_noise_pred[:, 4:, :, :]
817
+ cond_pred = noise_pred[:, 4:, :, :]
818
+
819
+ cond_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
820
+ else:
821
+ cond_pred = noise_pred[:, 4:, :, :]
822
+
823
+ # compute the previous noisy sample x_t -> x_t-1
824
+ depth_latent = self.scheduler.step(
825
+ cond_pred, t, depth_latent
826
+ ).prev_sample
827
+
828
+ depth = self.decode_depth(depth_latent)
829
+
830
+ # clip prediction
831
+ depth = torch.clip(depth, -1.0, 1.0)
832
+ # shift to [0, 1]
833
+ depth = (depth + 1.0) / 2.0
834
+
835
+ return depth
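+ # --- Note on the cfg_scale branch above (illustrative) ---
+ # When cfg_scale > 1, the depth prediction is extrapolated between a pass
+ # with reduced rgb-to-depth conditioning (rgb2depth_scale=0.3) and the fully
+ # conditioned pass, following the usual classifier-free-guidance form:
+ #
+ #     guided = uncond + cfg_scale * (cond - uncond)
+ #
+ # cfg_scale = 1.0 reduces to the plain conditioned prediction.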
836
+ @torch.no_grad()
837
+ def rgbd2rgbd(self,
838
+ input_image: Union[torch.Tensor, PIL.Image.Image],
839
+ depth_image: Union[torch.Tensor, PIL.Image.Image],
840
+ prompt: str = '',
841
+ guidance_scale: float = 7.5,
842
+ strength: float = 0.75,
843
+ generator: Union[torch.Generator, None] = None,
844
+ num_inference_steps: int = 50,
845
+ show_pbar: bool = False,
846
+ resample_method: str = "bilinear",
847
+ processing_res: int = 768
848
+ ) -> torch.Tensor:
849
+ self._check_inference_step(num_inference_steps)
850
+
851
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
852
+
853
+ # ----------------- encoder prompt -----------------
854
+ if isinstance(prompt, list):
855
+ bs = len(prompt)
856
+ batch_text_embed = []
857
+ for p in prompt:
858
+ batch_text_embed.append(self.encode_text(p))
859
+ batch_text_embed = torch.cat(batch_text_embed, dim=0)
860
+ elif isinstance(prompt, str):
861
+ bs = 1
862
+ batch_text_embed = self.encode_text(prompt).unsqueeze(0)
863
+ else:
864
+ raise NotImplementedError
865
+
866
+ if self.empty_text_embed is None:
867
+ self.encode_empty_text()
868
+ batch_empty_text_embed = self.empty_text_embed.repeat(
869
+ (batch_text_embed.shape[0], 1, 1)
870
+ ).to(self.device) # [B, 2, 1024]
871
+ text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)
872
+
873
+ # ----------------- Image Preprocess -----------------
874
+ # Convert to torch tensor
875
+ rgb = input_image
876
+ input_size = rgb.shape
877
+ assert (
878
+ 4 == rgb.dim() and 3 == input_size[-3]
879
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
880
+ if processing_res > 0:
881
+ rgb = resize_max_res(
882
+ rgb,
883
+ max_edge_resolution=processing_res,
884
+ resample_method=resample_method,
885
+ )
886
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
887
+ rgb_in = rgb_norm.to(self.dtype).to(self.device)
888
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
889
+
890
+ depth = depth_image
891
+ depth = depth.repeat(1, 3, 1, 1)
892
+ input_size = depth.shape
893
+ assert (
894
+ 4 == depth.dim() and 3 == input_size[-3]
895
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
896
+ if processing_res > 0:
897
+ depth = resize_max_res(
898
+ depth,
899
+ max_edge_resolution=processing_res,
900
+ resample_method=resample_method,
901
+ )
902
+ depth_norm: torch.Tensor = (depth - depth.min()) / (depth.max() - depth.min()) * 2.0 - 1.0 # [0, 255] -> [-1, 1]
903
+ depth_in = depth_norm.to(self.dtype).to(self.device)
904
+ assert depth_norm.min() >= -1.0 and depth_norm.max() <= 1.0
905
+
906
+ # Set timesteps
907
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
908
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
909
+ t_start = max(num_inference_steps - init_timestep, 0)
910
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
911
+ num_inference_steps = num_inference_steps - t_start
912
+ latent_timestep = timesteps[:1]
913
+
914
+ # Encode depth
915
+ rgb_latent = self.encode_rgb(rgb_in)
916
+ depth_latent = self.encode_depth(depth_in)
917
+ input_latent = torch.cat([rgb_latent, depth_latent], dim=1)
918
+ noise = torch.randn(
919
+ input_latent.shape,
920
+ device=self.device,
921
+ dtype=self.dtype,
922
+ generator=generator,
923
+ )
924
+
925
+ cat_latent = self.scheduler.add_noise(input_latent, noise, latent_timestep)
926
+ # noisy_latent = self.scheduler.add_noise(rgb_latent, noise, latent_timestep)
927
+ # cat_latent = torch.cat([noisy_latent, depth_latent], dim=1)
928
+
929
+ for i, t in enumerate(timesteps):
930
+ latent_model_input = torch.cat([cat_latent] * 2)
931
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
932
+
933
+ # predict the noise residual
934
+ with torch.no_grad():
935
+ noise_pred = self.unet(
936
+ latent_model_input,
937
+ rgb_timestep=t,
938
+ depth_timestep=t,
939
+ encoder_hidden_states=text_embed,
940
+ return_dict=False,
941
+ # separate_list=self.separate_list
942
+ )[0]
943
+ # perform guidance
944
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
945
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
946
+
947
+ # compute the previous noisy sample x_t -> x_t-1
948
+ cat_latent = self.scheduler.step(noise_pred, t, cat_latent).prev_sample
949
+
950
+ image = self.decode_image(cat_latent[:, :4, :, :])
951
+ image = self.numpy_to_pil(image)
952
+ d_image = self.decode_depth(cat_latent[:, 4:, :, :])
953
+ d_image = d_image.cpu().permute(0, 2, 3, 1).numpy()
954
+ for i in range(len(prompt)):
955
+ d_image[i] = (d_image[i] - d_image[i].min()) / (d_image[i].max() - d_image[i].min())
956
+ d_image = self.numpy_to_pil(d_image)
957
+
958
+ cat_image = make_image_grid([image[0], d_image[0]], rows=1, cols=2)
959
+ return cat_image
960
+
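+ # --- Note on `strength` in rgbd2rgbd above (illustrative) ---
+ # As in image-to-image diffusion, `strength` controls how much of the noise
+ # schedule is actually run: the input RGB-D latents are noised to an
+ # intermediate timestep and denoised from there.
+ #
+ #     init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ #     t_start = max(num_inference_steps - init_timestep, 0)
+ #     # e.g. num_inference_steps=50, strength=0.75 -> run the last 37 steps
+ #
+ # strength close to 1.0 starts from nearly pure noise, while small values
+ # keep the output close to the input RGB-D pair.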
961
+ @torch.no_grad()
962
+ def single_depth2image(
963
+ self,
964
+ depth_image: Union[torch.Tensor, PIL.Image.Image],
965
+ prompt: str = '',
966
+ generator: Union[torch.Generator, None] = None,
967
+ num_inference_steps: int = 50,
968
+ show_pbar: bool = False,
969
+ resample_method: str = "bilinear",
970
+ processing_res: int = 640
971
+ ) -> torch.Tensor:
972
+ """
973
+ Generate an RGB image conditioned on a depth map and a text prompt.
974
+
975
+ Args:
976
+ depth_image (`torch.Tensor` or `PIL.Image.Image`):
978
+ Input depth map used as conditioning.
978
+ num_inference_steps (`int`):
979
+ Number of diffusion denoising steps (DDIM) during inference.
980
+ show_pbar (`bool`):
981
+ Display a progress bar of diffusion denoising.
982
+ generator (`torch.Generator`)
983
+ Random generator for initial noise generation.
984
+ Returns:
985
+ `PIL.Image.Image`: Generated RGB image.
986
+ """
987
+ device = self.device
988
+ ori_type = self.dtype
989
+ # Check if denoising step is reasonable
990
+ self._check_inference_step(num_inference_steps)
991
+
992
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
993
+
994
+ # ----------------- Image Preprocess -----------------
995
+ # Convert to torch tensor
996
+ if isinstance(depth_image, Image.Image):
997
+ depth = pil_to_tensor(depth_image)
998
+ depth = depth.unsqueeze(0) # [1, rgb, H, W]
999
+ elif isinstance(depth_image, torch.Tensor):
1000
+ depth = depth_image
1001
+ else:
1002
+ raise TypeError(f"Unknown input type: {type(depth_image) = }")
1003
+ depth = depth.repeat(1, 3, 1, 1)
1004
+ input_size = depth.shape
1005
+ assert (
1006
+ 4 == depth.dim() and 3 == input_size[-3]
1007
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
1008
+
1009
+ # Resize image
1010
+ if processing_res > 0:
1011
+ depth = resize_max_res(
1012
+ depth,
1013
+ max_edge_resolution=processing_res,
1014
+ resample_method=resample_method,
1015
+ )
1016
+
1017
+ # Normalize rgb values
1018
+ depth_norm: torch.Tensor = depth / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
1019
+ assert depth_norm.min() >= -1.0 and depth_norm.max() <= 1.0
1020
+ depth_in = depth_norm.to(ori_type).to(device)
1021
+
1022
+ # Set timesteps
1023
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1024
+ timesteps = self.scheduler.timesteps # [T]
1025
+
1026
+ # Encode depth
1027
+ depth_latent = self.encode_depth(depth_in)
1028
+
1029
+ # Initial rgb map (noise)
1030
+ rgb_latent = torch.randn(
1031
+ depth_latent.shape,
1032
+ device=device,
1033
+ dtype=ori_type,
1034
+ generator=generator,
1035
+ ) * self.scheduler.init_noise_sigma # [B, 4, h, w]
1036
+
1037
+ # encode text input_ids
1038
+ if isinstance(prompt, list):
1039
+ bs = len(prompt)
1040
+ batch_text_embed = []
1041
+ for p in prompt:
1042
+ batch_text_embed.append(self.encode_text(p))
1043
+ batch_text_embed = torch.cat(batch_text_embed, dim=0)
1044
+ elif isinstance(prompt, str):
1045
+ bs = 1
1046
+ batch_text_embed = self.encode_text(prompt)
1047
+ else:
1048
+ raise NotImplementedError
1049
+
1050
+ if self.empty_text_embed is None:
1051
+ self.encode_empty_text()
1052
+ batch_empty_text_embed = self.empty_text_embed.repeat((batch_text_embed.shape[0], 1, 1)).to(device) # [B, 2, 1024]
1053
+
1054
+ text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)
1055
+
1056
+ # Denoising loop
1057
+ if show_pbar:
1058
+ iterable = tqdm(
1059
+ enumerate(timesteps),
1060
+ total=len(timesteps),
1061
+ leave=False,
1062
+ desc=" " * 4 + "Diffusion denoising",
1063
+ )
1064
+ else:
1065
+ iterable = enumerate(timesteps)
1066
+
1067
+ self.unet.to(torch.bfloat16)
1068
+ for i, t in iterable:
1069
+ cat_latent = torch.cat(
1070
+ [rgb_latent, depth_latent], dim=1
1071
+ ) # this order is important
1072
+ latent_model_input = torch.cat([cat_latent] * 2)
1073
+ # predict the noise residual
1074
+ with torch.no_grad():
1075
+ noise_pred = self.unet(
1076
+ latent_model_input.to(torch.bfloat16),
1077
+ rgb_timestep=t,
1078
+ depth_timestep=0,
1079
+ encoder_hidden_states=text_embed.to(torch.bfloat16),
1080
+ return_dict=False,
1081
+ )[0]
1082
+
1083
+ # perform guidance
1084
+ guidance_scale = 7.5
1085
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1086
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1087
+
1088
+ # compute the previous noisy sample x_t -> x_t-1
1089
+ rgb_latent = self.scheduler.step(noise_pred[:, :4, :, :], t, rgb_latent).prev_sample
1090
+
1091
+ image = self.decode_image(rgb_latent)
1092
+ image = self.numpy_to_pil(image)[0]
1093
+ image = image.resize((input_size[-1], input_size[-2]), Image.BILINEAR)
1094
+
1095
+ self.unet.to(ori_type)
1096
+
1097
+ return image
1098
+
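+ # --- single_depth2image usage sketch (illustrative) ---
+ # A hedged example of depth-conditioned image generation with the method
+ # above. The depth map is expected as a single-channel image (values in
+ # [0, 255]); the checkpoint path, file names and prompt are hypothetical.
+ #
+ #     import torch
+ #     from PIL import Image
+ #     from marigold import MarigoldPipeline
+ #
+ #     pipe = MarigoldPipeline.from_pretrained("./marigold_rgbd_checkpoint").to("cuda")
+ #     rgb = pipe.single_depth2image(
+ #         Image.open("example_depth.png"),  # single-channel depth map
+ #         prompt="a cozy bedroom at dusk",
+ #         num_inference_steps=50,
+ #         generator=torch.Generator("cuda").manual_seed(0),
+ #     )
+ #     rgb.save("generated_rgb.png")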
1099
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
1100
+ """
1101
+ Encode RGB image into latent.
1102
+
1103
+ Args:
1104
+ rgb_in (`torch.Tensor`):
1105
+ Input RGB image to be encoded.
1106
+
1107
+ Returns:
1108
+ `torch.Tensor`: Image latent.
1109
+ """
1110
+ # encode
1111
+ h = self.vae.encoder(rgb_in)
1112
+ moments = self.vae.quant_conv(h)
1113
+ mean, logvar = torch.chunk(moments, 2, dim=1)
1114
+ # scale latent
1115
+ rgb_latent = mean * self.rgb_latent_scale_factor
1116
+ return rgb_latent
1117
+
1118
+ def encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
1119
+ """
1120
+ Encode depth map into latent.
1121
+
1122
+ Args:
1123
+ depth_in (`torch.Tensor`):
1124
+ Input depth map to be encoded.
1125
+
1126
+ Returns:
1127
+ `torch.Tensor`: Depth latent.
1128
+ """
1129
+ # encode
1130
+ h = self.vae.encoder(depth_in)
1131
+ moments = self.vae.quant_conv(h)
1132
+ mean, logvar = torch.chunk(moments, 2, dim=1)
1133
+ # scale latent
1134
+ depth_latent = mean * self.depth_latent_scale_factor
1135
+ return depth_latent
1136
+
1137
+ def decode_image(self, latents):
1138
+ latents = 1 / self.vae.config.scaling_factor * latents
1139
+ z = self.vae.post_quant_conv(latents)
1140
+ image = self.vae.decoder(z)
1141
+ image = (image / 2 + 0.5).clamp(0, 1)
1142
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1143
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1144
+ return image
1145
+
1146
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
1147
+ """
1148
+ Decode depth latent into depth map.
1149
+
1150
+ Args:
1151
+ depth_latent (`torch.Tensor`):
1152
+ Depth latent to be decoded.
1153
+
1154
+ Returns:
1155
+ `torch.Tensor`: Decoded depth map.
1156
+ """
1157
+ # scale latent
1158
+ depth_latent = depth_latent / self.depth_latent_scale_factor
1159
+ # decode
1160
+ z = self.vae.post_quant_conv(depth_latent)
1161
+ stacked = self.vae.decoder(z)
1162
+ # mean of output channels
1163
+ depth_mean = stacked.mean(dim=1, keepdim=True)
1164
+ return depth_mean
1165
+
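+ # --- Note on decode_depth above (illustrative) ---
+ # The depth latent is decoded with the ordinary image VAE decoder, which
+ # produces three channels; averaging them yields a single-channel depth map,
+ # mirroring how depth is replicated to three channels before being passed to
+ # the VAE encoder (see encode_depth and the depth preprocessing above).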
1166
+ def post_process_rgbd(self, prompts, rgb_image, depth_image):
1167
+
1168
+ rgbd_images = []
1169
+ for idx, p in enumerate(prompts):
1170
+ image1, image2 = rgb_image[idx], depth_image[idx]
1171
+
1172
+ width1, height1 = image1.size
1173
+ width2, height2 = image2.size
1174
+
1175
+ font = ImageFont.load_default(size=20)
1176
+ text = p
1177
+ draw = ImageDraw.Draw(image1)
1178
+ text_bbox = draw.textbbox((0, 0), text, font=font)
1179
+ text_width = text_bbox[2] - text_bbox[0]
1180
+ text_height = text_bbox[3] - text_bbox[1]
1181
+
1182
+ new_image = Image.new('RGB', (width1 + width2, max(height1, height2) + text_height), (255, 255, 255))
1183
+
1184
+ text_x = (new_image.width - text_width) // 2
1185
+ text_y = 0
1186
+ draw = ImageDraw.Draw(new_image)
1187
+ draw.text((text_x, text_y), text, fill="black", font=font)
1188
+
1189
+ new_image.paste(image1, (0, text_height))
1190
+ new_image.paste(image2, (width1, text_height))
1191
+
1192
+ rgbd_images.append(pil_to_tensor(new_image))
1193
+
1194
+ return rgbd_images
marigold/marigold_xl_pipeline.py ADDED
@@ -0,0 +1,1046 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+ from typing import Any, Callable, Dict, List, Optional, Union
22
+ import logging
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ import pdb
25
+ from typing import Dict, Optional, Union
26
+ import torchvision.transforms as transforms
27
+ import PIL.Image
28
+ import numpy as np
29
+ import torch
30
+ from diffusers import (
31
+ AutoencoderKL,
32
+ DDIMScheduler,
33
+ DiffusionPipeline,
34
+ LCMScheduler,
35
+ UNet2DConditionModel,
36
+ )
37
+ from .duplicate_unet import DoubleUNet2DConditionModel
38
+ import os
39
+ from torch.nn import Conv2d
40
+ from PIL import Image, ImageDraw, ImageFont
41
+ from torch.nn.parameter import Parameter
42
+ from diffusers.utils import BaseOutput
43
+ from PIL import Image
44
+ from torch.utils.data import DataLoader, TensorDataset
45
+ from torchvision.transforms import InterpolationMode
46
+ from torchvision.transforms.functional import pil_to_tensor, resize
47
+ from tqdm.auto import tqdm
48
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
49
+
50
+ from .util.batchsize import find_batch_size
51
+ from .util.ensemble import ensemble_depth
52
+ from .util.image_util import (
53
+ chw2hwc,
54
+ colorize_depth_maps,
55
+ get_tv_resample_method,
56
+ resize_max_res,
57
+ )
58
+
59
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
60
+ """
61
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
62
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
63
+ """
64
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
65
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
66
+ # rescale the results from guidance (fixes overexposure)
67
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
68
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
69
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
70
+ return noise_cfg
71
+
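+ # --- rescale_noise_cfg usage sketch (illustrative) ---
+ # A hedged example of how the helper above is typically applied right after
+ # the classifier-free-guidance combination; the tensor names and the 0.7
+ # rescale factor are illustrative choices, not values fixed by this file.
+ #
+ #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)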
72
+ class MarigoldDepthOutput(BaseOutput):
73
+ """
74
+ Output class for Marigold monocular depth prediction pipeline.
75
+
76
+ Args:
77
+ depth_np (`np.ndarray`):
78
+ Predicted depth map, with depth values in the range of [0, 1].
79
+ depth_colored (`PIL.Image.Image`):
80
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
81
+ uncertainty (`None` or `np.ndarray`):
82
+ Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
83
+ """
84
+
85
+ depth_np: np.ndarray
86
+ depth_colored: Union[None, Image.Image]
87
+ uncertainty: Union[None, np.ndarray]
88
+
89
+ class MarigoldXLPipeline(DiffusionPipeline):
90
+ """
91
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
92
+
93
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
94
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
95
+
96
+ Args:
97
+ unet (`UNet2DConditionModel`):
98
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
99
+ vae (`AutoencoderKL`):
100
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
101
+ to and from latent representations.
102
+ scheduler (`DDIMScheduler`):
103
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
104
+ text_encoder (`CLIPTextModel`):
105
+ Text-encoder, for empty text embedding.
106
+ tokenizer (`CLIPTokenizer`):
107
+ CLIP tokenizer.
108
+ scale_invariant (`bool`, *optional*):
109
+ A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
110
+ the model config. When used together with the `shift_invariant=True` flag, the model is also called
111
+ "affine-invariant". NB: overriding this value is not supported.
112
+ shift_invariant (`bool`, *optional*):
113
+ A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
114
+ the model config. When used together with the `scale_invariant=True` flag, the model is also called
115
+ "affine-invariant". NB: overriding this value is not supported.
116
+ default_denoising_steps (`int`, *optional*):
117
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
118
+ quality with the given model. This value must be set in the model config. When the pipeline is called
119
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
120
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
121
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
122
+ default_processing_resolution (`int`, *optional*):
123
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
124
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
125
+ default value is used. This is required to ensure reasonable results with various model flavors trained
126
+ with varying optimal processing resolution values.
127
+ """
128
+
129
+ rgb_latent_scale_factor = 0.13025
130
+ depth_latent_scale_factor = 0.13025
131
+
132
+ def __init__(
133
+ self,
134
+ unet: DoubleUNet2DConditionModel,
135
+ vae: AutoencoderKL,
136
+ scheduler: Union[DDIMScheduler, LCMScheduler],
137
+ text_encoder: CLIPTextModel,
138
+ text_encoder_2: CLIPTextModelWithProjection,
139
+ tokenizer: CLIPTokenizer,
140
+ tokenizer_2: CLIPTokenizer,
141
+ scale_invariant: Optional[bool] = True,
142
+ shift_invariant: Optional[bool] = True,
143
+ default_denoising_steps: Optional[int] = None,
144
+ default_processing_resolution: Optional[int] = None,
145
+ requires_safety_checker: bool = False,
146
+ ):
147
+ super().__init__()
148
+
149
+ self.register_modules(
150
+ vae=vae,
151
+ text_encoder=text_encoder,
152
+ text_encoder_2=text_encoder_2,
153
+ tokenizer=tokenizer,
154
+ tokenizer_2=tokenizer_2,
155
+ unet=unet,
156
+ scheduler=scheduler,
157
+ )
158
+
159
+ self.register_to_config(
160
+ scale_invariant=scale_invariant,
161
+ shift_invariant=shift_invariant,
162
+ default_denoising_steps=default_denoising_steps,
163
+ default_processing_resolution=default_processing_resolution,
164
+ )
165
+
166
+ self.scale_invariant = scale_invariant
167
+ self.shift_invariant = shift_invariant
168
+ self.default_denoising_steps = default_denoising_steps
169
+ self.default_processing_resolution = default_processing_resolution
170
+
171
+ self.empty_text_embed = None
172
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
173
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
174
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
175
+ self.separate_list = [3,1,3]
176
+ self.default_sample_size = self.unet.config.sample_size
177
+
178
+ @torch.no_grad()
179
+ def __call__(
180
+ self,
181
+ input_image: Union[Image.Image, torch.Tensor],
182
+ denoising_steps: Optional[int] = None,
183
+ ensemble_size: int = 5,
184
+ processing_res: Optional[int] = None,
185
+ match_input_res: bool = True,
186
+ resample_method: str = "bilinear",
187
+ batch_size: int = 0,
188
+ generator: Union[torch.Generator, None] = None,
189
+ color_map: str = "Spectral",
190
+ show_progress_bar: bool = True,
191
+ ensemble_kwargs: Dict = None,
192
+ ) -> MarigoldDepthOutput:
193
+ """
194
+ Function invoked when calling the pipeline.
195
+
196
+ Args:
197
+ input_image (`Image`):
198
+ Input RGB (or gray-scale) image.
199
+ denoising_steps (`int`, *optional*, defaults to `None`):
200
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
201
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
202
+ for Marigold-LCM models.
203
+ ensemble_size (`int`, *optional*, defaults to `5`):
204
+ Number of predictions to be ensembled.
205
+ processing_res (`int`, *optional*, defaults to `None`):
206
+ Effective processing resolution. When set to `0`, processes at the original image resolution. This
207
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
208
+ value `None` resolves to the optimal value from the model config.
209
+ match_input_res (`bool`, *optional*, defaults to `True`):
210
+ Resize depth prediction to match input resolution.
211
+ Only valid if `processing_res` > 0.
212
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
213
+ Resampling method used to resize images and depth predictions. Must be one of `bilinear`, `bicubic`, or `nearest`.
214
+ batch_size (`int`, *optional*, defaults to `0`):
215
+ Inference batch size, no bigger than `ensemble_size`.
216
+ If set to 0, the script will automatically decide the proper batch size.
217
+ generator (`torch.Generator`, *optional*, defaults to `None`)
218
+ Random generator for initial noise generation.
219
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
220
+ Display a progress bar of diffusion denoising.
221
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
222
+ Colormap used to colorize the depth map.
223
+ scale_invariant (`bool`, *optional*, defaults to `True`):
224
+ Flag for scale-invariant prediction; if True, the scale is adjusted from the raw prediction.
225
+ shift_invariant (`bool`, *optional*, defaults to `True`):
226
+ Flag for shift-invariant prediction; if True, the shift is adjusted from the raw prediction; if False, the near plane is fixed at 0 m.
227
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
228
+ Arguments for detailed ensembling settings.
229
+ Returns:
230
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
231
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
232
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
233
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
234
+ coming from ensembling. None if `ensemble_size = 1`
235
+ """
236
+ # Model-specific optimal default values leading to fast and reasonable results.
237
+ if denoising_steps is None:
238
+ denoising_steps = self.default_denoising_steps
239
+ if processing_res is None:
240
+ processing_res = self.default_processing_resolution
241
+
242
+ assert processing_res >= 0
243
+ assert ensemble_size >= 1
244
+
245
+ # Check if denoising step is reasonable
246
+ self._check_inference_step(denoising_steps)
247
+
248
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
249
+
250
+ # ----------------- Image Preprocess -----------------
251
+ # Convert to torch tensor
252
+ if isinstance(input_image, Image.Image):
253
+ input_image = input_image.convert("RGB")
254
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
255
+ rgb = pil_to_tensor(input_image)
256
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
257
+ elif isinstance(input_image, torch.Tensor):
258
+ rgb = input_image
259
+ else:
260
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
261
+ input_size = rgb.shape
262
+ assert (
263
+ 4 == rgb.dim() and 3 == input_size[-3]
264
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
265
+
266
+ # Resize image
267
+ if processing_res > 0:
268
+ rgb = resize_max_res(
269
+ rgb,
270
+ max_edge_resolution=processing_res,
271
+ resample_method=resample_method,
272
+ )
273
+
274
+ # Normalize rgb values
275
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
276
+ rgb_norm = rgb_norm.to(self.dtype)
277
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
278
+
279
+ # ----------------- Predicting depth -----------------
280
+ # Batch repeated input image
281
+ duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
282
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
283
+ if batch_size > 0:
284
+ _bs = batch_size
285
+ else:
286
+ _bs = find_batch_size(
287
+ ensemble_size=ensemble_size,
288
+ input_res=max(rgb_norm.shape[1:]),
289
+ dtype=self.dtype,
290
+ )
291
+
292
+ single_rgb_loader = DataLoader(
293
+ single_rgb_dataset, batch_size=_bs, shuffle=False
294
+ )
295
+
296
+ # Predict depth maps (batched)
297
+ depth_pred_ls = []
298
+ if show_progress_bar:
299
+ iterable = tqdm(
300
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
301
+ )
302
+ else:
303
+ iterable = single_rgb_loader
304
+ for batch in iterable:
305
+ (batched_img,) = batch
306
+ depth_pred_raw = self.single_infer(
307
+ rgb_in=batched_img,
308
+ num_inference_steps=denoising_steps,
309
+ show_pbar=show_progress_bar,
310
+ generator=generator,
311
+ )
312
+ depth_pred_ls.append(depth_pred_raw.detach())
313
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
314
+ torch.cuda.empty_cache() # clear vram cache for ensembling
315
+
316
+ # ----------------- Test-time ensembling -----------------
317
+ if ensemble_size > 1:
318
+ depth_pred, pred_uncert = ensemble_depth(
319
+ depth_preds,
320
+ scale_invariant=self.scale_invariant,
321
+ shift_invariant=self.shift_invariant,
322
+ max_res=50,
323
+ **(ensemble_kwargs or {}),
324
+ )
325
+ else:
326
+ depth_pred = depth_preds
327
+ pred_uncert = None
328
+
329
+ # Resize back to original resolution
330
+ if match_input_res:
331
+ depth_pred = resize(
332
+ depth_pred,
333
+ input_size[-2:],
334
+ interpolation=resample_method,
335
+ antialias=True,
336
+ )
337
+
338
+ # Convert to numpy
339
+ depth_pred = depth_pred.squeeze()
340
+ depth_pred = depth_pred.cpu().numpy()
341
+ if pred_uncert is not None:
342
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
343
+
344
+ # Clip output range
345
+ depth_pred = depth_pred.clip(0, 1)
346
+
347
+ # Colorize
348
+ if color_map is not None:
349
+ depth_colored = colorize_depth_maps(
350
+ depth_pred, 0, 1, cmap=color_map
351
+ ).squeeze() # [3, H, W], value in (0, 1)
352
+ depth_colored = (depth_colored * 255).astype(np.uint8)
353
+ depth_colored_hwc = chw2hwc(depth_colored)
354
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
355
+ else:
356
+ depth_colored_img = None
357
+
358
+ return MarigoldDepthOutput(
359
+ depth_np=depth_pred,
360
+ depth_colored=depth_colored_img,
361
+ uncertainty=pred_uncert,
362
+ )
363
+
364
+ def _replace_unet_conv_in(self):
365
+ # replace the first layer to accept 8 in_channels
366
+ _weight = self.unet.conv_in.weight.clone() # [320, 4, 3, 3]
367
+ _bias = self.unet.conv_in.bias.clone() # [320]
368
+ zero_weight = torch.zeros(_weight.shape).to(_weight.device)
369
+ _weight = torch.cat([_weight, zero_weight], dim=1)
370
+ # _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
371
+ # half the activation magnitude
372
+ # _weight *= 0.5
373
+ # new conv_in channel
374
+ _n_convin_out_channel = self.unet.conv_in.out_channels
375
+ _new_conv_in = Conv2d(
376
+ 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
377
+ )
378
+ _new_conv_in.weight = Parameter(_weight)
379
+ _new_conv_in.bias = Parameter(_bias)
380
+ self.unet.conv_in = _new_conv_in
381
+ logging.info("Unet conv_in layer is replaced")
382
+ # replace config
383
+ self.unet.config["in_channels"] = 8
384
+ logging.info("Unet config is updated")
385
+ return
386
+
387
+ def _replace_unet_conv_out(self):
388
+ # replace the output layer to produce 8 out_channels
389
+ _weight = self.unet.conv_out.weight.clone() # [4, 320, 3, 3]
390
+ _bias = self.unet.conv_out.bias.clone() # [4]
391
+ _weight = _weight.repeat((2, 1, 1, 1)) # Keep selected channel(s)
392
+ _bias = _bias.repeat((2))
393
+ # half the activation magnitude
394
+
395
+ # new conv_in channel
396
+ _n_convin_out_channel = self.unet.conv_out.out_channels
397
+ _new_conv_out = Conv2d(
398
+ _n_convin_out_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
399
+ )
400
+ _new_conv_out.weight = Parameter(_weight)
401
+ _new_conv_out.bias = Parameter(_bias)
402
+ self.unet.conv_out = _new_conv_out
403
+ logging.info("Unet conv_out layer is replaced")
404
+ # replace config
405
+ self.unet.config["out_channels"] = 8
406
+ logging.info("Unet config is updated")
407
+ return
408
+
409
+ def _check_inference_step(self, n_step: int) -> None:
410
+ """
411
+ Check if denoising step is reasonable
412
+ Args:
413
+ n_step (`int`): denoising steps
414
+ """
415
+ assert n_step >= 1
416
+
417
+ if isinstance(self.scheduler, DDIMScheduler):
418
+ if n_step < 10:
419
+ logging.warning(
420
+ f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
421
+ )
422
+ elif isinstance(self.scheduler, LCMScheduler):
423
+ if not 1 <= n_step <= 4:
424
+ logging.warning(
425
+ f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
426
+ )
427
+ else:
428
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
429
+
430
+ def encode_text(self, prompt):
431
+ """
432
+ Encode text embedding for empty prompt
433
+ """
434
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
435
+ text_encoders = (
436
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
437
+ )
438
+ prompts = [prompt, prompt]
439
+ prompt_embeds_list = []
440
+
441
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
442
+ text_inputs = tokenizer(
443
+ prompt,
444
+ padding="max_length",
445
+ max_length=tokenizer.model_max_length,
446
+ truncation=True,
447
+ return_tensors="pt",
448
+ )
449
+
450
+ text_input_ids = text_inputs.input_ids
451
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
452
+
453
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
454
+ text_input_ids, untruncated_ids
455
+ ):
456
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])
457
+ print(
458
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
459
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
460
+ )
461
+
462
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
463
+
464
+ pooled_prompt_embeds = prompt_embeds[0]
465
+ prompt_embeds = prompt_embeds.hidden_states[-2]
466
+ prompt_embeds_list.append(prompt_embeds)
467
+
468
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
469
+
470
+ return prompt_embeds, pooled_prompt_embeds
471
+
472
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
473
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
474
+
475
+ passed_add_embed_dim = (
476
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
477
+ )
478
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
479
+
480
+ if expected_add_embed_dim != passed_add_embed_dim:
481
+ raise ValueError(
482
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
483
+ )
484
+
485
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
486
+ return add_time_ids
487
+
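+ # --- Note on _get_add_time_ids above (illustrative) ---
+ # SDXL-style micro-conditioning: the added time ids are the concatenation
+ # (original_size + crops_coords_top_left + target_size), i.e. six integers
+ # that are embedded and added to the time embedding. A hedged example for a
+ # 1024x1024 generation without cropping:
+ #
+ #     add_time_ids = self._get_add_time_ids(
+ #         (1024, 1024), (0, 0), (1024, 1024), dtype=prompt_embeds.dtype
+ #     )
+ #     # -> tensor of shape [1, 6]: [[1024, 1024, 0, 0, 1024, 1024]]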
488
+ def numpy_to_pil(self, images: np.ndarray) -> List[PIL.Image.Image]:
489
+ """
490
+ Convert a numpy image or a batch of images to a list of PIL images.
491
+ """
492
+ if images.ndim == 3:
493
+ images = images[None, ...]
494
+ images = (images * 255).round().astype("uint8")
495
+ if images.shape[-1] == 1:
496
+ # special case for grayscale (single channel) images
497
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
498
+ else:
499
+ pil_images = [Image.fromarray(image) for image in images]
500
+
501
+ return pil_images
502
+
503
+ @torch.no_grad()
504
+ def generate_rgbd(
505
+ self,
506
+ prompt: Union[str, list],
507
+ num_inference_steps: int,
508
+ generator: Union[torch.Generator, None],
509
+ show_pbar: bool = None,
510
+ negative_prompt: Union[str, list] = '',
511
+ color_map: str = "Spectral",
512
+ height: int = 1024,
513
+ width: int = 1024,
514
+ guidance_scale: float = 5.5
515
+ ):
516
+ """
517
+ Generate an aligned RGB image and depth map from a text prompt.
518
+
519
+ Args:
520
+ prompt (`str` or `list`):
521
+ Text prompt(s) to condition the generation on.
522
+ num_inference_steps (`int`):
523
+ Number of diffusion denoising steps (DDIM) during inference.
524
+ show_pbar (`bool`):
525
+ Display a progress bar of diffusion denoising.
526
+ generator (`torch.Generator`)
527
+ Random generator for initial noise generation.
528
+ Returns:
529
+ Generated RGB image(s) and the corresponding depth map(s).
530
+ """
531
+ device = self.device
532
+ ori_type = self.dtype
533
+
534
+ height = height or self.default_sample_size * self.vae_scale_factor
535
+ width = width or self.default_sample_size * self.vae_scale_factor
536
+
537
+ original_size = (height, width)
538
+ target_size = (height, width)
539
+
540
+ # Set timesteps
541
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
542
+ timesteps = self.scheduler.timesteps # [T]
543
+
544
+ # prepare text embeddings
545
+ if isinstance(prompt, list):
546
+ bs = len(prompt)
547
+ batch_text_embed = []
548
+ batch_pooled_text_embed = []
549
+ for p in prompt:
550
+ prompt_embed, pooled_prompt_embed = self.encode_text(p)
551
+ batch_text_embed.append(prompt_embed)
552
+ batch_pooled_text_embed.append(pooled_prompt_embed)
553
+ batch_text_embed = torch.cat(batch_text_embed, dim=0)
554
+ batch_pooled_text_embed = torch.cat(batch_pooled_text_embed, dim=0)
555
+ elif isinstance(prompt, str):
556
+ bs = 1
557
+ batch_text_embed, batch_pooled_text_embed = self.encode_text(prompt)
558
+ else:
559
+ raise NotImplementedError
560
+
561
+ batch_empty_text_embed = torch.zeros_like(batch_text_embed).to(device) # [B, 77, d]
562
+ batch_pooled_empty_text_embed = torch.zeros_like(batch_pooled_text_embed).to(device)
563
+
564
+ # prepare added time ids & embeddings
565
+ add_time_ids = self._get_add_time_ids(
566
+ original_size, (0, 0), target_size, dtype=batch_text_embed.dtype
567
+ )
568
+ negative_add_time_ids = add_time_ids
569
+
570
+ prompt_embeds = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)
571
+ add_text_embeds = torch.cat([batch_pooled_empty_text_embed, batch_pooled_text_embed], dim=0)
572
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
573
+
574
+ prompt_embeds = prompt_embeds.to(device).to(torch.bfloat16)
575
+ add_text_embeds = add_text_embeds.to(device).to(torch.bfloat16)
576
+ add_time_ids = add_time_ids.to(device).repeat(bs, 1)
577
+
578
+ # Initial depth map (noise)
579
+ cat_latent = torch.randn(
580
+ [bs, self.unet.config["in_channels"], height // self.vae_scale_factor, width // self.vae_scale_factor],
581
+ device=device,
582
+ dtype=torch.bfloat16,
583
+ generator=generator,
584
+ ) # [B, 8, h, w]
585
+ cat_latent = cat_latent * self.scheduler.init_noise_sigma
586
+
587
+ # Denoising loop
588
+ if show_pbar:
589
+ iterable = tqdm(
590
+ enumerate(timesteps),
591
+ total=len(timesteps),
592
+ leave=False,
593
+ desc=" " * 4 + "Diffusion denoising",
594
+ )
595
+ else:
596
+ iterable = enumerate(timesteps)
597
+
598
+ self.to(torch.bfloat16)
599
+
600
+ for i, t in iterable:
601
+ latent_model_input = torch.cat([cat_latent] * 2)
602
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
603
+
604
+ # predict the noise residual
605
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
606
+ with torch.no_grad():
607
+ noise_pred = self.unet(
608
+ latent_model_input,
609
+ t,
610
+ t,
611
+ encoder_hidden_states=prompt_embeds,
612
+ added_cond_kwargs=added_cond_kwargs,
613
+ separate_list=self.separate_list,
614
+ return_dict=False,
615
+ )[0]
616
+
617
+ # perform guidance
618
+ guidance_scale = guidance_scale
619
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
620
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
621
+
622
+ # compute the previous noisy sample x_t -> x_t-1
623
+ cat_latent = self.scheduler.step(noise_pred, t, cat_latent, generator=generator).prev_sample
624
+
625
+ # self.unet.to(default_dtype)
626
+ # cat_latent.to(default_dtype)
627
+ image = self.decode_image(cat_latent[:, 0:4, :, :])
628
+ image = self.numpy_to_pil(image)
629
+
630
+ depth = self.decode_depth(cat_latent[:, 4:, :, :])
631
+ depth = torch.clip(depth, -1.0, 1.0)
632
+ depth = (depth + 1.0) / 2.0
633
+ depth_pred = depth.squeeze()
634
+ depth_pred = depth_pred.float().cpu().numpy()
635
+ depth_pred = depth_pred.clip(0, 1)
636
+
637
+ # Colorize
638
+ if color_map is not None:
639
+ depth_colored_img = []
640
+ depth_colored = colorize_depth_maps(
641
+ depth_pred, 0, 1, cmap=color_map
642
+ ).squeeze() # [3, H, W], value in (0, 1)
643
+ depth_colored_img = self.numpy_to_pil(np.transpose(depth_colored, (0, 2, 3, 1)))
644
+ else:
645
+ depth_colored_img = None
646
+
647
+ rgbd_images = self.post_process_rgbd(prompt, image, depth_colored_img)
648
+ self.to(ori_type)
649
+
650
+ return rgbd_images
651
+
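A hedged usage sketch for `generate_rgbd`. The checkpoint path and prompt are placeholders, and the class name assumes the enclosing pipeline is the one exported as `MarigoldPipeline` by this package; in this repository it may actually be a different pipeline class.

```py
# Hypothetical usage sketch; class name, checkpoint path and prompt are placeholders.
import torch
from marigold import MarigoldPipeline  # assumes this repo's package is importable

pipe = MarigoldPipeline.from_pretrained("path/to/checkpoint").to("cuda")

rgbd = pipe.generate_rgbd(
    prompt=["a cozy reading nook next to a large window"],
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
    show_pbar=True,
)  # list of side-by-side RGB / colored-depth tensors, one per prompt
```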
652
+ @torch.no_grad()
653
+ def image2depth(self,
654
+ input_image: Union[Image.Image, torch.Tensor],
655
+ denoising_steps: Optional[int] = None,
656
+ ensemble_size: int = 5,
657
+ processing_res: Optional[int] = None,
658
+ match_input_res: bool = True,
659
+ resample_method: str = "bilinear",
660
+ batch_size: int = 0,
661
+ generator: Union[torch.Generator, None] = None,
662
+ color_map: str = "Spectral",
663
+ show_progress_bar: bool = True,
664
+ ensemble_kwargs: Dict = None,
665
+ ):
666
+ # Model-specific optimal default values leading to fast and reasonable results.
667
+ if denoising_steps is None:
668
+ denoising_steps = self.default_denoising_steps
669
+ if processing_res is None:
670
+ processing_res = self.default_processing_resolution
671
+
672
+ ori_type = self.dtype
673
+ self.to(torch.bfloat16)
674
+
675
+ assert processing_res >= 0
676
+ assert ensemble_size >= 1
677
+
678
+ # Check if denoising step is reasonable
679
+ self._check_inference_step(denoising_steps)
680
+
681
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
682
+
683
+ # ----------------- Image Preprocess -----------------
684
+ # Convert to torch tensor
685
+ if isinstance(input_image, Image.Image):
686
+ input_image = input_image.convert("RGB")
687
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
688
+ rgb = pil_to_tensor(input_image)
689
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
690
+ elif isinstance(input_image, torch.Tensor):
691
+ rgb = input_image
692
+ else:
693
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
694
+ input_size = rgb.shape
695
+ assert (
696
+ 4 == rgb.dim() and 3 == input_size[-3]
697
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
698
+
699
+ # Resize image
700
+ if processing_res > 0:
701
+ rgb = resize_max_res(
702
+ rgb,
703
+ max_edge_resolution=processing_res,
704
+ resample_method=resample_method,
705
+ )
706
+
707
+ # Normalize rgb values
708
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
709
+ rgb_norm = rgb_norm.to(self.dtype)
710
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
711
+
712
+ # ----------------- Predicting depth -----------------
713
+ # Batch repeated input image
714
+ duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
715
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
716
+ if batch_size > 0:
717
+ _bs = batch_size
718
+ else:
719
+ _bs = find_batch_size(
720
+ ensemble_size=ensemble_size,
721
+ input_res=max(rgb_norm.shape[1:]),
722
+ dtype=self.dtype,
723
+ )
724
+
725
+ single_rgb_loader = DataLoader(
726
+ single_rgb_dataset, batch_size=_bs, shuffle=False
727
+ )
728
+
729
+ # Predict depth maps (batched)
730
+ depth_pred_ls = []
731
+ if show_progress_bar:
732
+ iterable = tqdm(
733
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
734
+ )
735
+ else:
736
+ iterable = single_rgb_loader
737
+ for batch in iterable:
738
+ (batched_img,) = batch
739
+ depth_pred_raw = self.single_image2depth(
740
+ rgb_in=batched_img,
741
+ num_inference_steps=denoising_steps,
742
+ show_pbar=show_progress_bar,
743
+ generator=generator,
744
+ )
745
+ depth_pred_ls.append(depth_pred_raw.detach())
746
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
747
+ torch.cuda.empty_cache() # clear vram cache for ensembling
748
+ depth_preds = depth_preds.to(torch.float32)
749
+ # ----------------- Test-time ensembling -----------------
750
+ if ensemble_size > 1:
751
+ depth_pred, pred_uncert = ensemble_depth(
752
+ depth_preds,
753
+ scale_invariant=self.scale_invariant,
754
+ shift_invariant=self.shift_invariant,
755
+ max_res=50,
756
+ **(ensemble_kwargs or {}),
757
+ )
758
+ else:
759
+ depth_pred = depth_preds
760
+ pred_uncert = None
761
+
762
+ # Resize back to original resolution
763
+ if match_input_res:
764
+ depth_pred = resize(
765
+ depth_pred,
766
+ input_size[-2:],
767
+ interpolation=resample_method,
768
+ antialias=True,
769
+ )
770
+
771
+ # Convert to numpy
772
+ depth_pred = depth_pred.squeeze()
773
+ depth_pred = depth_pred.cpu().numpy()
774
+ if pred_uncert is not None:
775
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
776
+
777
+ # Clip output range
778
+ depth_pred = depth_pred.clip(0, 1)
779
+
780
+ # Colorize
781
+ if color_map is not None:
782
+ depth_colored = colorize_depth_maps(
783
+ depth_pred, 0, 1, cmap=color_map
784
+ ).squeeze() # [3, H, W], value in (0, 1)
785
+ depth_colored = (depth_colored * 255).astype(np.uint8)
786
+ depth_colored_hwc = chw2hwc(depth_colored)
787
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
788
+ else:
789
+ depth_colored_img = None
790
+
791
+ self.to(ori_type)
792
+
793
+ return MarigoldDepthOutput(
794
+ depth_np=depth_pred,
795
+ depth_colored=depth_colored_img,
796
+ uncertainty=pred_uncert,
797
+ )
798
+
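A hedged usage sketch for `image2depth`, reusing the hypothetical `pipe` object from the sketch above; the input and output paths are placeholders.

```py
# Hypothetical usage sketch; `pipe` and the file paths are placeholders.
from PIL import Image

rgb = Image.open("example.jpg")                # placeholder input photo
out = pipe.image2depth(
    rgb,
    denoising_steps=10,
    ensemble_size=5,
    processing_res=768,
    show_progress_bar=True,
)
out.depth_colored.save("depth_colored.png")    # out.depth_np holds the raw depth map in [0, 1]
```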
799
+ @torch.no_grad()
800
+ def single_image2depth(
801
+ self,
802
+ rgb_in: torch.Tensor,
803
+ num_inference_steps: int,
804
+ generator: Union[torch.Generator, None],
805
+ show_pbar: bool
806
+ ) -> torch.Tensor:
807
+ """
808
+ Perform an individual depth prediction without ensembling.
809
+
810
+ Args:
811
+ rgb_in (`torch.Tensor`):
812
+ Input RGB image.
813
+ num_inference_steps (`int`):
814
+                 Number of diffusion denoising steps (DDIM) during inference.
815
+ show_pbar (`bool`):
816
+ Display a progress bar of diffusion denoising.
817
+ generator (`torch.Generator`)
818
+ Random generator for initial noise generation.
819
+ Returns:
820
+ `torch.Tensor`: Predicted depth map.
821
+ """
822
+ device = self.device
823
+ rgb_in = rgb_in.to(device)
824
+
825
+ # Set timesteps
826
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
827
+ timesteps = self.scheduler.timesteps # [T]
828
+ # Encode image
829
+ rgb_latent = self.encode_rgb(rgb_in)
830
+
831
+ # Initial depth map (noise)
832
+ depth_latent = torch.randn(
833
+ rgb_latent.shape,
834
+ device=device,
835
+ dtype=self.dtype,
836
+ generator=generator,
837
+ ) # [B, 4, h, w]
838
+
839
+ # Batched empty text embedding
840
+ if self.empty_text_embed is None:
841
+ self.encode_empty_text()
842
+ batch_empty_text_embed = self.empty_text_embed.repeat(
843
+ (rgb_latent.shape[0], 1, 1)
844
+ ).to(device).to(self.dtype) # [B, 2, 1024]
845
+
846
+ # Denoising loop
847
+ if show_pbar:
848
+ iterable = tqdm(
849
+ enumerate(timesteps),
850
+ total=len(timesteps),
851
+ leave=False,
852
+ desc=" " * 4 + "Diffusion denoising",
853
+ )
854
+ else:
855
+ iterable = enumerate(timesteps)
856
+
857
+ for i, t in iterable:
858
+ unet_input = torch.cat(
859
+ [rgb_latent, depth_latent], dim=1
860
+ ) # this order is important
861
+ # predict the noise residual
862
+ noise_pred = self.unet(
863
+ unet_input, rgb_timestep=0, depth_timestep=t, encoder_hidden_states=batch_empty_text_embed
864
+ ).sample # [B, 4, h, w]
865
+
866
+ # compute the previous noisy sample x_t -> x_t-1
867
+ depth_latent = self.scheduler.step(
868
+ noise_pred[:, 4:, :, :], t, depth_latent, generator=generator
869
+ ).prev_sample
870
+
871
+ depth = self.decode_depth(depth_latent)
872
+
873
+ # clip prediction
874
+ depth = torch.clip(depth, -1.0, 1.0)
875
+ # shift to [0, 1]
876
+ depth = (depth + 1.0) / 2.0
877
+
878
+ return depth
879
+
880
+ def single_depth2image(
881
+ self,
882
+ depth_in: torch.Tensor,
883
+ prompt,
884
+ num_inference_steps: int,
885
+ generator: Union[torch.Generator, None],
886
+ show_pbar: bool
887
+ ) -> torch.Tensor:
888
+ """
889
+         Generate an RGB image conditioned on a depth map and a text prompt.
890
+
891
+ Args:
892
+             depth_in (`torch.Tensor`):
893
+                 Input depth map.
894
+ num_inference_steps (`int`):
895
+                 Number of diffusion denoising steps (DDIM) during inference.
896
+ show_pbar (`bool`):
897
+ Display a progress bar of diffusion denoising.
898
+ generator (`torch.Generator`)
899
+ Random generator for initial noise generation.
900
+ Returns:
901
+             `np.ndarray`: Generated RGB image.
902
+ """
903
+ device = self.device
904
+
905
+ depth_in = depth_in.to(device)
906
+
907
+ # Set timesteps
908
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
909
+ timesteps = self.scheduler.timesteps # [T]
910
+
911
+ # Encode depth
912
+ depth_latent = self.encode_rgb(depth_in)
913
+
914
+ # Initial rgb map (noise)
915
+ rgb_latent = torch.randn(
916
+ depth_latent.shape,
917
+ device=device,
918
+ dtype=self.dtype,
919
+ generator=generator,
920
+ ) # [B, 4, h, w]
921
+
922
+ # encode text
923
+         prompt_embed, pooled_prompt_embed = self.encode_text(prompt)
924
+
925
+ # Denoising loop
926
+ if show_pbar:
927
+ iterable = tqdm(
928
+ enumerate(timesteps),
929
+ total=len(timesteps),
930
+ leave=False,
931
+ desc=" " * 4 + "Diffusion denoising",
932
+ )
933
+ else:
934
+ iterable = enumerate(timesteps)
935
+
936
+ for i, t in iterable:
937
+ unet_input = torch.cat(
938
+ [rgb_latent, depth_latent], dim=1
939
+ ) # this order is important
940
+
941
+ # predict the noise residual
942
+ noise_pred = self.unet(
943
+                 unet_input, rgb_timestep=t, depth_timestep=0, encoder_hidden_states=prompt_embed
944
+ ).sample # [B, 4, h, w]
945
+
946
+ # compute the previous noisy sample x_t -> x_t-1
947
+ rgb_latent = self.scheduler.step(
948
+ noise_pred[:, 0:4, :, :], t, rgb_latent, generator=generator
949
+ ).prev_sample
950
+
951
+ image = self.decode_image(rgb_latent)
952
+
953
+ return image
954
+
955
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
956
+ """
957
+ Encode RGB image into latent.
958
+
959
+ Args:
960
+ rgb_in (`torch.Tensor`):
961
+ Input RGB image to be encoded.
962
+
963
+ Returns:
964
+ `torch.Tensor`: Image latent.
965
+ """
966
+ # encode
967
+ h = self.vae.encoder(rgb_in)
968
+ moments = self.vae.quant_conv(h)
969
+ mean, logvar = torch.chunk(moments, 2, dim=1)
970
+ # scale latent
971
+ rgb_latent = mean * self.rgb_latent_scale_factor
972
+ return rgb_latent
973
+
974
+ def encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
975
+ """
976
+         Encode depth map into latent.
977
+
978
+ Args:
979
+             depth_in (`torch.Tensor`):
980
+                 Input depth map to be encoded.
981
+
982
+ Returns:
983
+             `torch.Tensor`: Depth latent.
984
+ """
985
+ # encode
986
+ h = self.vae.encoder(depth_in)
987
+ moments = self.vae.quant_conv(h)
988
+ mean, logvar = torch.chunk(moments, 2, dim=1)
989
+ # scale latent
990
+         depth_latent = mean * self.rgb_latent_scale_factor
991
+         return depth_latent
992
+
993
+ def decode_image(self, latents):
994
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
995
+ image = self.image_processor.postprocess(image, output_type='np')
996
+ return image
997
+
998
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
999
+ """
1000
+ Decode depth latent into depth map.
1001
+
1002
+ Args:
1003
+ depth_latent (`torch.Tensor`):
1004
+ Depth latent to be decoded.
1005
+
1006
+ Returns:
1007
+ `torch.Tensor`: Decoded depth map.
1008
+ """
1009
+ # scale latent
1010
+ depth_latent = depth_latent / self.depth_latent_scale_factor
1011
+ # decode
1012
+ z = self.vae.post_quant_conv(depth_latent)
1013
+ stacked = self.vae.decoder(z)
1014
+ # mean of output channels
1015
+ depth_mean = stacked.mean(dim=1, keepdim=True)
1016
+ return depth_mean
1017
+
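A toy illustration of the depth decoding convention above: the VAE decoder emits an image-like 3-channel tensor, and the depth map is taken as the per-pixel channel mean. The tensor here is random dummy data standing in for the decoder output.

```py
import torch

stacked = torch.rand(1, 3, 8, 8)                 # stand-in for self.vae.decoder(z) output
depth_mean = stacked.mean(dim=1, keepdim=True)   # collapse channels to a single depth plane
print(depth_mean.shape)                          # torch.Size([1, 1, 8, 8])
```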
1018
+ def post_process_rgbd(self, prompts, rgb_image, depth_image):
1019
+
1020
+ rgbd_images = []
1021
+ for idx, p in enumerate(prompts):
1022
+ image1, image2 = rgb_image[idx], depth_image[idx]
1023
+
1024
+ width1, height1 = image1.size
1025
+ width2, height2 = image2.size
1026
+
1027
+ font = ImageFont.load_default(size=20)
1028
+ text = p
1029
+ draw = ImageDraw.Draw(image1)
1030
+ text_bbox = draw.textbbox((0, 0), text, font=font)
1031
+ text_width = text_bbox[2] - text_bbox[0]
1032
+ text_height = text_bbox[3] - text_bbox[1]
1033
+
1034
+ new_image = Image.new('RGB', (width1 + width2, max(height1, height2) + text_height), (255, 255, 255))
1035
+
1036
+ text_x = (new_image.width - text_width) // 2
1037
+ text_y = 0
1038
+ draw = ImageDraw.Draw(new_image)
1039
+ draw.text((text_x, text_y), text, fill="black", font=font)
1040
+
1041
+ new_image.paste(image1, (0, text_height))
1042
+ new_image.paste(image2, (width1, text_height))
1043
+
1044
+ rgbd_images.append(pil_to_tensor(new_image))
1045
+
1046
+ return rgbd_images
marigold/pipeline_stable_diffusion_inpaint.py ADDED
@@ -0,0 +1,1068 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from packaging import version
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
23
+
24
+ from diffusers.configuration_utils import FrozenDict
25
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
29
+ from diffusers.schedulers import KarrasDiffusionSchedulers
30
+ from diffusers.utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
33
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
34
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
41
+ """
42
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
43
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
44
+ ``image`` and ``1`` for the ``mask``.
45
+
46
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
47
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
48
+
49
+ Args:
50
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
51
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
52
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
53
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
54
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
55
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
56
+
57
+
58
+ Raises:
59
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
60
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
61
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
62
+             (or the other way around).
63
+
64
+ Returns:
65
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
66
+ dimensions: ``batch x channels x height x width``.
67
+ """
68
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
69
+ deprecate(
70
+ "prepare_mask_and_masked_image",
71
+ "0.30.0",
72
+ deprecation_message,
73
+ )
74
+ if image is None:
75
+ raise ValueError("`image` input cannot be undefined.")
76
+
77
+ if mask is None:
78
+ raise ValueError("`mask_image` input cannot be undefined.")
79
+
80
+ if isinstance(image, torch.Tensor):
81
+ if not isinstance(mask, torch.Tensor):
82
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
83
+
84
+ # Batch single image
85
+ if image.ndim == 3:
86
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
87
+ image = image.unsqueeze(0)
88
+
89
+ # Batch and add channel dim for single mask
90
+ if mask.ndim == 2:
91
+ mask = mask.unsqueeze(0).unsqueeze(0)
92
+
93
+ # Batch single mask or add channel dim
94
+ if mask.ndim == 3:
95
+ # Single batched mask, no channel dim or single mask not batched but channel dim
96
+ if mask.shape[0] == 1:
97
+ mask = mask.unsqueeze(0)
98
+
99
+ # Batched masks no channel dim
100
+ else:
101
+ mask = mask.unsqueeze(1)
102
+
103
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
104
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
105
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
106
+
107
+ # Check image is in [-1, 1]
108
+ if image.min() < -1 or image.max() > 1:
109
+ raise ValueError("Image should be in [-1, 1] range")
110
+
111
+ # Check mask is in [0, 1]
112
+ if mask.min() < 0 or mask.max() > 1:
113
+ raise ValueError("Mask should be in [0, 1] range")
114
+
115
+ # Binarize mask
116
+ mask[mask < 0.5] = 0
117
+ mask[mask >= 0.5] = 1
118
+
119
+ # Image as float32
120
+ image = image.to(dtype=torch.float32)
121
+ elif isinstance(mask, torch.Tensor):
122
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
123
+ else:
124
+ # preprocess image
125
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
126
+ image = [image]
127
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
128
+             # resize all images w.r.t. the passed height and width
129
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
130
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
131
+ image = np.concatenate(image, axis=0)
132
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
133
+ image = np.concatenate([i[None, :] for i in image], axis=0)
134
+
135
+ image = image.transpose(0, 3, 1, 2)
136
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
137
+
138
+ # preprocess mask
139
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
140
+ mask = [mask]
141
+
142
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
143
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
144
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
145
+ mask = mask.astype(np.float32) / 255.0
146
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
147
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
148
+
149
+ mask[mask < 0.5] = 0
150
+ mask[mask >= 0.5] = 1
151
+ mask = torch.from_numpy(mask)
152
+
153
+ masked_image = image * (mask < 0.5)
154
+
155
+ # n.b. ensure backwards compatibility as old function does not return image
156
+ if return_image:
157
+ return mask, masked_image, image
158
+
159
+ return mask, masked_image
160
+
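A toy check of the mask convention implemented above: values >= 0.5 mark pixels to inpaint, and the masked image keeps only the untouched pixels. The image and mask are random dummy tensors.

```py
import torch

image = torch.rand(1, 3, 4, 4) * 2 - 1                 # fake image in [-1, 1]
mask = (torch.rand(1, 1, 4, 4) > 0.5).float()          # binary mask, 1 = region to inpaint

masked_image = image * (mask < 0.5)                    # zero out the region to be inpainted
assert masked_image[mask.bool().expand_as(image)].abs().sum() == 0
```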
161
+
162
+ class StableDiffusionInpaintPipeline(
163
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
164
+ ):
165
+ r"""
166
+ Pipeline for text-guided image inpainting using Stable Diffusion.
167
+
168
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
169
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
170
+
171
+ The pipeline also inherits the following loading methods:
172
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
173
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
174
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
175
+
176
+ Args:
177
+ vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
178
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
179
+ text_encoder ([`CLIPTextModel`]):
180
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
181
+ tokenizer ([`~transformers.CLIPTokenizer`]):
182
+ A `CLIPTokenizer` to tokenize text.
183
+ unet ([`UNet2DConditionModel`]):
184
+ A `UNet2DConditionModel` to denoise the encoded image latents.
185
+ scheduler ([`SchedulerMixin`]):
186
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
187
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
188
+ safety_checker ([`StableDiffusionSafetyChecker`]):
189
+ Classification module that estimates whether generated images could be considered offensive or harmful.
190
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
191
+ about a model's potential harms.
192
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
193
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
194
+ """
195
+ model_cpu_offload_seq = "text_encoder->unet->vae"
196
+ _optional_components = ["safety_checker", "feature_extractor"]
197
+ _exclude_from_cpu_offload = ["safety_checker"]
198
+
199
+ def __init__(
200
+ self,
201
+ vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
202
+ text_encoder: CLIPTextModel,
203
+ tokenizer: CLIPTokenizer,
204
+ unet: UNet2DConditionModel,
205
+ scheduler: KarrasDiffusionSchedulers,
206
+ safety_checker: StableDiffusionSafetyChecker,
207
+ feature_extractor: CLIPImageProcessor,
208
+ requires_safety_checker: bool = True,
209
+ ):
210
+ super().__init__()
211
+
212
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
213
+ deprecation_message = (
214
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
215
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
216
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
217
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
218
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
219
+ " file"
220
+ )
221
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
222
+ new_config = dict(scheduler.config)
223
+ new_config["steps_offset"] = 1
224
+ scheduler._internal_dict = FrozenDict(new_config)
225
+
226
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
227
+ deprecation_message = (
228
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
229
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
230
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
231
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
232
+ " Hub, it would be very nice if you could open a Pull request for the"
233
+ " `scheduler/scheduler_config.json` file"
234
+ )
235
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
236
+ new_config = dict(scheduler.config)
237
+ new_config["skip_prk_steps"] = True
238
+ scheduler._internal_dict = FrozenDict(new_config)
239
+
240
+ if safety_checker is None and requires_safety_checker:
241
+ logger.warning(
242
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
243
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
244
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
245
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
246
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
247
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
248
+ )
249
+
250
+ if safety_checker is not None and feature_extractor is None:
251
+ raise ValueError(
252
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
253
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
254
+ )
255
+
256
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
257
+ version.parse(unet.config._diffusers_version).base_version
258
+ ) < version.parse("0.9.0.dev0")
259
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
260
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
261
+ deprecation_message = (
262
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
263
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
264
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
265
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
266
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
267
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
268
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
269
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
270
+ " the `unet/config.json` file"
271
+ )
272
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
273
+ new_config = dict(unet.config)
274
+ new_config["sample_size"] = 64
275
+ unet._internal_dict = FrozenDict(new_config)
276
+
277
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
278
+ if unet.config.in_channels != 9:
279
+ logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
280
+
281
+ self.register_modules(
282
+ vae=vae,
283
+ text_encoder=text_encoder,
284
+ tokenizer=tokenizer,
285
+ unet=unet,
286
+ scheduler=scheduler,
287
+ safety_checker=safety_checker,
288
+ feature_extractor=feature_extractor,
289
+ )
290
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
291
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
292
+ self.mask_processor = VaeImageProcessor(
293
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
294
+ )
295
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
296
+
297
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
298
+ def _encode_prompt(
299
+ self,
300
+ prompt,
301
+ device,
302
+ num_images_per_prompt,
303
+ do_classifier_free_guidance,
304
+ negative_prompt=None,
305
+ prompt_embeds: Optional[torch.FloatTensor] = None,
306
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
307
+ lora_scale: Optional[float] = None,
308
+ **kwargs,
309
+ ):
310
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
311
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
312
+
313
+ prompt_embeds_tuple = self.encode_prompt(
314
+ prompt=prompt,
315
+ device=device,
316
+ num_images_per_prompt=num_images_per_prompt,
317
+ do_classifier_free_guidance=do_classifier_free_guidance,
318
+ negative_prompt=negative_prompt,
319
+ prompt_embeds=prompt_embeds,
320
+ negative_prompt_embeds=negative_prompt_embeds,
321
+ lora_scale=lora_scale,
322
+ **kwargs,
323
+ )
324
+
325
+ # concatenate for backwards comp
326
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
327
+
328
+ return prompt_embeds
329
+
330
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
331
+ def encode_prompt(
332
+ self,
333
+ prompt,
334
+ device,
335
+ num_images_per_prompt,
336
+ do_classifier_free_guidance,
337
+ negative_prompt=None,
338
+ prompt_embeds: Optional[torch.FloatTensor] = None,
339
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
340
+ lora_scale: Optional[float] = None,
341
+ clip_skip: Optional[int] = None,
342
+ ):
343
+ r"""
344
+ Encodes the prompt into text encoder hidden states.
345
+
346
+ Args:
347
+ prompt (`str` or `List[str]`, *optional*):
348
+ prompt to be encoded
349
+ device: (`torch.device`):
350
+ torch device
351
+ num_images_per_prompt (`int`):
352
+ number of images that should be generated per prompt
353
+ do_classifier_free_guidance (`bool`):
354
+ whether to use classifier free guidance or not
355
+ negative_prompt (`str` or `List[str]`, *optional*):
356
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
357
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
358
+ less than `1`).
359
+ prompt_embeds (`torch.FloatTensor`, *optional*):
360
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
361
+ provided, text embeddings will be generated from `prompt` input argument.
362
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
363
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
364
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
365
+ argument.
366
+ lora_scale (`float`, *optional*):
367
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
368
+ clip_skip (`int`, *optional*):
369
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
370
+ the output of the pre-final layer will be used for computing the prompt embeddings.
371
+ """
372
+ # set lora scale so that monkey patched LoRA
373
+ # function of text encoder can correctly access it
374
+
375
+ if prompt is not None and isinstance(prompt, str):
376
+ batch_size = 1
377
+ elif prompt is not None and isinstance(prompt, list):
378
+ batch_size = len(prompt)
379
+ else:
380
+ batch_size = prompt_embeds.shape[0]
381
+
382
+ if prompt_embeds is None:
383
+             # textual inversion: process multi-vector tokens if necessary
384
+ if isinstance(self, TextualInversionLoaderMixin):
385
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
386
+
387
+ text_inputs = self.tokenizer(
388
+ prompt,
389
+ padding="max_length",
390
+ max_length=self.tokenizer.model_max_length,
391
+ truncation=True,
392
+ return_tensors="pt",
393
+ )
394
+ text_input_ids = text_inputs.input_ids
395
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
396
+
397
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
398
+ text_input_ids, untruncated_ids
399
+ ):
400
+ removed_text = self.tokenizer.batch_decode(
401
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
402
+ )
403
+ logger.warning(
404
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
405
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
406
+ )
407
+
408
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
409
+ attention_mask = text_inputs.attention_mask.to(device)
410
+ else:
411
+ attention_mask = None
412
+
413
+ if clip_skip is None:
414
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
415
+ prompt_embeds = prompt_embeds[0]
416
+ else:
417
+ prompt_embeds = self.text_encoder(
418
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
419
+ )
420
+ # Access the `hidden_states` first, that contains a tuple of
421
+ # all the hidden states from the encoder layers. Then index into
422
+ # the tuple to access the hidden states from the desired layer.
423
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
424
+ # We also need to apply the final LayerNorm here to not mess with the
425
+ # representations. The `last_hidden_states` that we typically use for
426
+ # obtaining the final prompt representations passes through the LayerNorm
427
+ # layer.
428
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
429
+
430
+ if self.text_encoder is not None:
431
+ prompt_embeds_dtype = self.text_encoder.dtype
432
+ elif self.unet is not None:
433
+ prompt_embeds_dtype = self.unet.dtype
434
+ else:
435
+ prompt_embeds_dtype = prompt_embeds.dtype
436
+
437
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
438
+
439
+ bs_embed, seq_len, _ = prompt_embeds.shape
440
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
441
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
442
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
443
+
444
+ # get unconditional embeddings for classifier free guidance
445
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
446
+ uncond_tokens: List[str]
447
+ if negative_prompt is None:
448
+ uncond_tokens = [""] * batch_size
449
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
450
+ raise TypeError(
451
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
452
+ f" {type(prompt)}."
453
+ )
454
+ elif isinstance(negative_prompt, str):
455
+ uncond_tokens = [negative_prompt]
456
+ elif batch_size != len(negative_prompt):
457
+ raise ValueError(
458
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
459
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
460
+ " the batch size of `prompt`."
461
+ )
462
+ else:
463
+ uncond_tokens = negative_prompt
464
+
465
+             # textual inversion: process multi-vector tokens if necessary
466
+ if isinstance(self, TextualInversionLoaderMixin):
467
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
468
+
469
+ max_length = prompt_embeds.shape[1]
470
+ uncond_input = self.tokenizer(
471
+ uncond_tokens,
472
+ padding="max_length",
473
+ max_length=max_length,
474
+ truncation=True,
475
+ return_tensors="pt",
476
+ )
477
+
478
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
479
+ attention_mask = uncond_input.attention_mask.to(device)
480
+ else:
481
+ attention_mask = None
482
+
483
+ negative_prompt_embeds = self.text_encoder(
484
+ uncond_input.input_ids.to(device),
485
+ attention_mask=attention_mask,
486
+ )
487
+ negative_prompt_embeds = negative_prompt_embeds[0]
488
+
489
+ if do_classifier_free_guidance:
490
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
491
+ seq_len = negative_prompt_embeds.shape[1]
492
+
493
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
494
+
495
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
496
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
497
+
498
+ return prompt_embeds, negative_prompt_embeds
499
+
500
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
501
+ def run_safety_checker(self, image, device, dtype):
502
+ if self.safety_checker is None:
503
+ has_nsfw_concept = None
504
+ else:
505
+ if torch.is_tensor(image):
506
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
507
+ else:
508
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
509
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
510
+ image, has_nsfw_concept = self.safety_checker(
511
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
512
+ )
513
+ return image, has_nsfw_concept
514
+
515
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
516
+ def prepare_extra_step_kwargs(self, generator, eta):
517
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
518
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
519
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
520
+ # and should be between [0, 1]
521
+
522
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
523
+ extra_step_kwargs = {}
524
+ if accepts_eta:
525
+ extra_step_kwargs["eta"] = eta
526
+
527
+ # check if the scheduler accepts generator
528
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
529
+ if accepts_generator:
530
+ extra_step_kwargs["generator"] = generator
531
+ return extra_step_kwargs
532
+
533
+ def check_inputs(
534
+ self,
535
+ prompt,
536
+ height,
537
+ width,
538
+ strength,
539
+ callback_steps,
540
+ negative_prompt=None,
541
+ prompt_embeds=None,
542
+ negative_prompt_embeds=None,
543
+ ):
544
+ if strength < 0 or strength > 1:
545
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
546
+
547
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
548
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
549
+
550
+ if (callback_steps is None) or (
551
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
552
+ ):
553
+ raise ValueError(
554
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
555
+ f" {type(callback_steps)}."
556
+ )
557
+
558
+ if prompt is not None and prompt_embeds is not None:
559
+ raise ValueError(
560
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
561
+ " only forward one of the two."
562
+ )
563
+ elif prompt is None and prompt_embeds is None:
564
+ raise ValueError(
565
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
566
+ )
567
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
568
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
569
+
570
+ if negative_prompt is not None and negative_prompt_embeds is not None:
571
+ raise ValueError(
572
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
573
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
574
+ )
575
+
576
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
577
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
578
+ raise ValueError(
579
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
580
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
581
+ f" {negative_prompt_embeds.shape}."
582
+ )
583
+
584
+ def prepare_latents(
585
+ self,
586
+ batch_size,
587
+ num_channels_latents,
588
+ height,
589
+ width,
590
+ dtype,
591
+ device,
592
+ generator,
593
+ latents=None,
594
+ image=None,
595
+ timestep=None,
596
+ is_strength_max=True,
597
+ return_noise=False,
598
+ return_image_latents=False,
599
+ ):
600
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
601
+ if isinstance(generator, list) and len(generator) != batch_size:
602
+ raise ValueError(
603
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
604
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
605
+ )
606
+
607
+ if (image is None or timestep is None) and not is_strength_max:
608
+ raise ValueError(
609
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
610
+ "However, either the image or the noise timestep has not been provided."
611
+ )
612
+
613
+ if return_image_latents or (latents is None and not is_strength_max):
614
+ image = image.to(device=device, dtype=dtype)
615
+
616
+ if image.shape[1] == 4:
617
+ image_latents = image
618
+ else:
619
+ image_latents = self._encode_vae_image(image=image, generator=generator)
620
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
621
+
622
+ if latents is None:
623
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
624
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
625
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
626
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
627
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
628
+ else:
629
+ noise = latents.to(device)
630
+ latents = noise * self.scheduler.init_noise_sigma
631
+
632
+ outputs = (latents,)
633
+
634
+ if return_noise:
635
+ outputs += (noise,)
636
+
637
+ if return_image_latents:
638
+ outputs += (image_latents,)
639
+
640
+ return outputs
641
+
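A toy sketch of the two initialisation branches in `prepare_latents`: pure noise when `strength == 1.0`, otherwise image latents noised to the chosen starting timestep. The scheduler, shapes, and timestep here are illustrative only.

```py
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
image_latents = torch.zeros(1, 4, 64, 64)      # stand-in for encoded image latents
noise = torch.randn(1, 4, 64, 64)

is_strength_max = False
timestep = torch.tensor([500])
latents = noise if is_strength_max else scheduler.add_noise(image_latents, noise, timestep)
print(latents.shape)                           # torch.Size([1, 4, 64, 64])
```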
642
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
643
+ if isinstance(generator, list):
644
+ image_latents = [
645
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
646
+ for i in range(image.shape[0])
647
+ ]
648
+ image_latents = torch.cat(image_latents, dim=0)
649
+ else:
650
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
651
+
652
+ image_latents = self.vae.config.scaling_factor * image_latents
653
+
654
+ return image_latents
655
+
656
+ def prepare_mask_latents(
657
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
658
+ ):
659
+ # resize the mask to latents shape as we concatenate the mask to the latents
660
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
661
+ # and half precision
662
+ mask = torch.nn.functional.interpolate(
663
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
664
+ )
665
+ mask = mask.to(device=device, dtype=dtype)
666
+
667
+ masked_image = masked_image.to(device=device, dtype=dtype)
668
+
669
+ if masked_image.shape[1] == 4:
670
+ masked_image_latents = masked_image
671
+ else:
672
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
673
+
674
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
675
+ if mask.shape[0] < batch_size:
676
+ if not batch_size % mask.shape[0] == 0:
677
+ raise ValueError(
678
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
679
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
680
+ " of masks that you pass is divisible by the total requested batch size."
681
+ )
682
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
683
+ if masked_image_latents.shape[0] < batch_size:
684
+ if not batch_size % masked_image_latents.shape[0] == 0:
685
+ raise ValueError(
686
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
687
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
688
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
689
+ )
690
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
691
+
692
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
693
+ masked_image_latents = (
694
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
695
+ )
696
+
697
+ # aligning device to prevent device errors when concating it with the latent model input
698
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
699
+ return mask, masked_image_latents
700
+
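A toy check of the mask downscaling in `prepare_mask_latents`: the pixel-space mask is resized to the latent grid, i.e. 1/8 resolution for the usual `vae_scale_factor` of 8 (the value here is assumed, not read from a model).

```py
import torch
import torch.nn.functional as F

vae_scale_factor = 8                                   # typical value for SD VAEs (assumed)
mask = torch.ones(1, 1, 512, 512)
latent_mask = F.interpolate(mask, size=(512 // vae_scale_factor, 512 // vae_scale_factor))
print(latent_mask.shape)                               # torch.Size([1, 1, 64, 64])
```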
701
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
702
+ def get_timesteps(self, num_inference_steps, strength, device):
703
+ # get the original timestep using init_timestep
704
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
705
+
706
+ t_start = max(num_inference_steps - init_timestep, 0)
707
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
708
+
709
+ return timesteps, num_inference_steps - t_start
710
+
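A small arithmetic sketch of how `strength` trims the schedule in `get_timesteps`: with 50 steps and `strength=0.6`, denoising skips the first 20 timesteps and runs the remaining 30.

```py
num_inference_steps, strength = 50, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
print(t_start, num_inference_steps - t_start)   # 20 30
```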
711
+ @torch.no_grad()
712
+ def __call__(
713
+ self,
714
+ prompt: Union[str, List[str]] = None,
715
+ image: PipelineImageInput = None,
716
+ mask_image: PipelineImageInput = None,
717
+ masked_image_latents: torch.FloatTensor = None,
718
+ height: Optional[int] = None,
719
+ width: Optional[int] = None,
720
+ strength: float = 1.0,
721
+ num_inference_steps: int = 50,
722
+ guidance_scale: float = 7.5,
723
+ negative_prompt: Optional[Union[str, List[str]]] = None,
724
+ num_images_per_prompt: Optional[int] = 1,
725
+ eta: float = 0.0,
726
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
727
+ latents: Optional[torch.FloatTensor] = None,
728
+ prompt_embeds: Optional[torch.FloatTensor] = None,
729
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
730
+ output_type: Optional[str] = "pil",
731
+ return_dict: bool = True,
732
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
733
+ callback_steps: int = 1,
734
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
735
+ clip_skip: int = None,
736
+ ):
737
+ r"""
738
+ The call function to the pipeline for generation.
739
+
740
+ Args:
741
+ prompt (`str` or `List[str]`, *optional*):
742
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
743
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
744
+ `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
745
+ be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
746
+ tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
747
+ expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
748
+ expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
749
+ if passing latents directly it is not encoded again.
750
+ mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
751
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
752
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
753
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
754
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
755
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
756
+ 1)`, or `(H, W)`.
757
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
758
+ The height in pixels of the generated image.
759
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
760
+ The width in pixels of the generated image.
761
+ strength (`float`, *optional*, defaults to 1.0):
762
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
763
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
764
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
765
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
766
+ essentially ignores `image`.
767
+ num_inference_steps (`int`, *optional*, defaults to 50):
768
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
769
+ expense of slower inference. This parameter is modulated by `strength`.
770
+ guidance_scale (`float`, *optional*, defaults to 7.5):
771
+ A higher guidance scale value encourages the model to generate images closely linked to the text
772
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
773
+ negative_prompt (`str` or `List[str]`, *optional*):
774
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
775
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
776
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
777
+ The number of images to generate per prompt.
778
+ eta (`float`, *optional*, defaults to 0.0):
779
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
780
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
781
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
782
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
783
+ generation deterministic.
784
+ latents (`torch.FloatTensor`, *optional*):
785
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
786
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
787
+ tensor is generated by sampling using the supplied random `generator`.
788
+ prompt_embeds (`torch.FloatTensor`, *optional*):
789
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
790
+ provided, text embeddings are generated from the `prompt` input argument.
791
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
792
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
793
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
794
+ output_type (`str`, *optional*, defaults to `"pil"`):
795
+ The output format of the generated image. Choose between `PIL.Image` and `np.array`.
796
+ return_dict (`bool`, *optional*, defaults to `True`):
797
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
798
+ plain tuple.
799
+ callback (`Callable`, *optional*):
800
+ A function that is called every `callback_steps` steps during inference. The function is called with the
801
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
802
+ callback_steps (`int`, *optional*, defaults to 1):
803
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
804
+ every step.
805
+ cross_attention_kwargs (`dict`, *optional*):
806
+ A kwargs dictionary that, if specified, is passed along to the [`AttentionProcessor`] as defined in
807
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
808
+ clip_skip (`int`, *optional*):
809
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
810
+ the output of the pre-final layer will be used for computing the prompt embeddings.
811
+ Examples:
812
+
813
+ ```py
814
+ >>> import PIL
815
+ >>> import requests
816
+ >>> import torch
817
+ >>> from io import BytesIO
818
+
819
+ >>> from diffusers import StableDiffusionInpaintPipeline
820
+
821
+
822
+ >>> def download_image(url):
823
+ ... response = requests.get(url)
824
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
825
+
826
+
827
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
828
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
829
+
830
+ >>> init_image = download_image(img_url).resize((512, 512))
831
+ >>> mask_image = download_image(mask_url).resize((512, 512))
832
+
833
+ >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
834
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
835
+ ... )
836
+ >>> pipe = pipe.to("cuda")
837
+
838
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
839
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
840
+ ```
841
+
842
+ Returns:
843
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
844
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
845
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
846
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
847
+ "not-safe-for-work" (nsfw) content.
848
+ """
849
+ # 0. Default height and width to unet
850
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
851
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
852
+
853
+ # 1. Check inputs
854
+ self.check_inputs(
855
+ prompt,
856
+ height,
857
+ width,
858
+ strength,
859
+ callback_steps,
860
+ negative_prompt,
861
+ prompt_embeds,
862
+ negative_prompt_embeds,
863
+ )
864
+
865
+ # 2. Define call parameters
866
+ if prompt is not None and isinstance(prompt, str):
867
+ batch_size = 1
868
+ elif prompt is not None and isinstance(prompt, list):
869
+ batch_size = len(prompt)
870
+ else:
871
+ batch_size = prompt_embeds.shape[0]
872
+
873
+ device = self._execution_device
874
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
875
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
876
+ # corresponds to doing no classifier free guidance.
877
+ do_classifier_free_guidance = guidance_scale > 1.0
878
+
879
+ # 3. Encode input prompt
880
+ text_encoder_lora_scale = (
881
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
882
+ )
883
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
884
+ prompt,
885
+ device,
886
+ num_images_per_prompt,
887
+ do_classifier_free_guidance,
888
+ negative_prompt,
889
+ prompt_embeds=prompt_embeds,
890
+ negative_prompt_embeds=negative_prompt_embeds,
891
+ lora_scale=text_encoder_lora_scale,
892
+ clip_skip=clip_skip,
893
+ )
894
+ # For classifier free guidance, we need to do two forward passes.
895
+ # Here we concatenate the unconditional and text embeddings into a single batch
896
+ # to avoid doing two forward passes
897
+ if do_classifier_free_guidance:
898
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
899
+
900
+ # 4. set timesteps
901
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
902
+ timesteps, num_inference_steps = self.get_timesteps(
903
+ num_inference_steps=num_inference_steps, strength=strength, device=device
904
+ )
905
+ # check that number of inference steps is not < 1 - as this doesn't make sense
906
+ if num_inference_steps < 1:
907
+ raise ValueError(
908
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
909
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
910
+ )
911
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
912
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
913
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
914
+ is_strength_max = strength == 1.0
915
+
916
+ # 5. Preprocess mask and image
917
+
918
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
919
+ init_image = init_image.to(dtype=torch.float32)
920
+
921
+ # 6. Prepare latent variables
922
+ num_channels_latents = self.vae.config.latent_channels
923
+ num_channels_unet = self.unet.config.in_channels
924
+ return_image_latents = num_channels_unet == 4
925
+
926
+ latents_outputs = self.prepare_latents(
927
+ batch_size * num_images_per_prompt,
928
+ num_channels_latents,
929
+ height,
930
+ width,
931
+ prompt_embeds.dtype,
932
+ device,
933
+ generator,
934
+ latents,
935
+ image=init_image,
936
+ timestep=latent_timestep,
937
+ is_strength_max=is_strength_max,
938
+ return_noise=True,
939
+ return_image_latents=return_image_latents,
940
+ )
941
+
942
+ if return_image_latents:
943
+ latents, noise, image_latents = latents_outputs
944
+ else:
945
+ latents, noise = latents_outputs
946
+
947
+ # 7. Prepare mask latent variables
948
+ mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
949
+
950
+ if masked_image_latents is None:
951
+ masked_image = init_image * (mask_condition < 0.5)
952
+ else:
953
+ masked_image = masked_image_latents
954
+
955
+ mask, masked_image_latents = self.prepare_mask_latents(
956
+ mask_condition,
957
+ masked_image,
958
+ batch_size * num_images_per_prompt,
959
+ height,
960
+ width,
961
+ prompt_embeds.dtype,
962
+ device,
963
+ generator,
964
+ do_classifier_free_guidance,
965
+ )
966
+
967
+ # 8. Check that sizes of mask, masked image and latents match
968
+ if num_channels_unet == 9:
969
+ # default case for runwayml/stable-diffusion-inpainting
970
+ num_channels_mask = mask.shape[1]
971
+ num_channels_masked_image = masked_image_latents.shape[1]
972
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
973
+ raise ValueError(
974
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
975
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
976
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
977
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
978
+ " `pipeline.unet` or your `mask_image` or `image` input."
979
+ )
980
+ elif num_channels_unet != 4:
981
+ raise ValueError(
982
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
983
+ )
984
+
985
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
986
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
987
+
988
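+ # NOTE: this variant returns the prepared latents, mask, and masked-image latents here, so the denoising loop below is not executed.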
+ return latents, mask, masked_image, masked_image_latents
989
+
990
+ # 10. Denoising loop
991
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
992
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
993
+ for i, t in enumerate(timesteps):
994
+ # expand the latents if we are doing classifier free guidance
995
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
996
+
997
+ # concat latents, mask, masked_image_latents in the channel dimension
998
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
999
+
1000
+ if num_channels_unet == 9:
1001
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1002
+
1003
+ # predict the noise residual
1004
+ noise_pred = self.unet(
1005
+ latent_model_input,
1006
+ t,
1007
+ encoder_hidden_states=prompt_embeds,
1008
+ cross_attention_kwargs=cross_attention_kwargs,
1009
+ return_dict=False,
1010
+ )[0]
1011
+
1012
+ # perform guidance
1013
+ if do_classifier_free_guidance:
1014
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1015
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1016
+
1017
+ # compute the previous noisy sample x_t -> x_t-1
1018
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1019
+ if num_channels_unet == 4:
1020
+ init_latents_proper = image_latents
1021
+ if do_classifier_free_guidance:
1022
+ init_mask, _ = mask.chunk(2)
1023
+ else:
1024
+ init_mask = mask
1025
+
1026
+ if i < len(timesteps) - 1:
1027
+ noise_timestep = timesteps[i + 1]
1028
+ init_latents_proper = self.scheduler.add_noise(
1029
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1030
+ )
1031
+
1032
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1033
+
1034
+ # call the callback, if provided
1035
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1036
+ progress_bar.update()
1037
+ if callback is not None and i % callback_steps == 0:
1038
+ step_idx = i // getattr(self.scheduler, "order", 1)
1039
+ callback(step_idx, t, latents)
1040
+
1041
+ if not output_type == "latent":
1042
+ condition_kwargs = {}
1043
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
1044
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
1045
+ init_image_condition = init_image.clone()
1046
+ init_image = self._encode_vae_image(init_image, generator=generator)
1047
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
1048
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
1049
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
1050
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1051
+ else:
1052
+ image = latents
1053
+ has_nsfw_concept = None
1054
+
1055
+ if has_nsfw_concept is None:
1056
+ do_denormalize = [True] * image.shape[0]
1057
+ else:
1058
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1059
+
1060
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1061
+
1062
+ # Offload all models
1063
+ self.maybe_free_model_hooks()
1064
+
1065
+ if not return_dict:
1066
+ return (image, has_nsfw_concept)
1067
+
1068
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
marigold/util/__pycache__/batchsize.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
marigold/util/__pycache__/ensemble.cpython-310.pyc ADDED
Binary file (6.52 kB). View file
 
marigold/util/__pycache__/image_util.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
marigold/util/batchsize.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import torch
22
+ import math
23
+
24
+
25
+ # Search table for suggested max. inference batch size
26
+ bs_search_table = [
27
+ # tested on A100-PCIE-80GB
28
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
29
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
30
+ # tested on A100-PCIE-40GB
31
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
32
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
33
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
34
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
35
+ # tested on RTX3090, RTX4090
36
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
37
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
38
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
39
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
40
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
41
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
42
+ # tested on GTX1080Ti
43
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
44
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
45
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
46
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
47
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
48
+ ]
49
+
50
+
51
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
52
+ """
53
+ Automatically search for a suitable operating batch size.
54
+
55
+ Args:
56
+ ensemble_size (`int`):
57
+ Number of predictions to be ensembled.
58
+ input_res (`int`):
59
+ Operating resolution of the input image.
60
+
61
+ Returns:
62
+ `int`: Operating batch size.
63
+ """
64
+ if not torch.cuda.is_available():
65
+ return 1
66
+
67
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
68
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
69
+ for settings in sorted(
70
+ filtered_bs_search_table,
71
+ key=lambda k: (k["res"], -k["total_vram"]),
72
+ ):
73
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
74
+ bs = settings["bs"]
75
+ if bs > ensemble_size:
76
+ bs = ensemble_size
77
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
78
+ bs = math.ceil(ensemble_size / 2)
79
+ return bs
80
+
81
+ return 1
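A minimal usage sketch of `find_batch_size` (hedged: the `marigold.util.batchsize` import path is inferred from the file path above, and the argument values are illustrative only):

```py
import torch

from marigold.util.batchsize import find_batch_size  # assumed import path

# Pick a batch size for ensembling 10 predictions at a 768-px processing resolution.
# Returns 1 on CPU-only machines; on GPU, the search table above is consulted for the
# detected VRAM and requested dtype, and the result never exceeds the ensemble size.
bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(bs)
```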
marigold/util/ensemble.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ from functools import partial
22
+ from typing import Optional, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ from .image_util import get_tv_resample_method, resize_max_res
28
+
29
+
30
+ def inter_distances(tensors: torch.Tensor):
31
+ """
32
+ Calculate the distance between each pair of depth maps.
33
+ """
34
+ distances = []
35
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
36
+ arr1 = tensors[i : i + 1]
37
+ arr2 = tensors[j : j + 1]
38
+ distances.append(arr1 - arr2)
39
+ dist = torch.concatenate(distances, dim=0)
40
+ return dist
41
+
42
+
43
+ def ensemble_depth(
44
+ depth: torch.Tensor,
45
+ scale_invariant: bool = True,
46
+ shift_invariant: bool = True,
47
+ output_uncertainty: bool = False,
48
+ reduction: str = "median",
49
+ regularizer_strength: float = 0.02,
50
+ max_iter: int = 2,
51
+ tol: float = 1e-3,
52
+ max_res: int = 1024,
53
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
54
+ """
55
+ Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
56
+ number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
57
+ depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
58
+ alignment happens when the predictions have one or more degrees of freedom, that is when they are either
59
+ affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
60
+ `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
61
+ alignment is skipped and only ensembling is performed.
62
+
63
+ Args:
64
+ depth (`torch.Tensor`):
65
+ Input ensemble depth maps.
66
+ scale_invariant (`bool`, *optional*, defaults to `True`):
67
+ Whether to treat predictions as scale-invariant.
68
+ shift_invariant (`bool`, *optional*, defaults to `True`):
69
+ Whether to treat predictions as shift-invariant.
70
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
71
+ Whether to output uncertainty map.
72
+ reduction (`str`, *optional*, defaults to `"median"`):
73
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
74
+ `"median"`.
75
+ regularizer_strength (`float`, *optional*, defaults to `0.02`):
76
+ Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
77
+ max_iter (`int`, *optional*, defaults to `2`):
78
+ Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
79
+ argument.
80
+ tol (`float`, *optional*, defaults to `1e-3`):
81
+ Alignment solver tolerance. The solver stops when the tolerance is reached.
82
+ max_res (`int`, *optional*, defaults to `1024`):
83
+ Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
84
+ Returns:
85
+ A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
86
+ `(1, 1, H, W)`.
87
+ """
88
+ if depth.dim() != 4 or depth.shape[1] != 1:
89
+ raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
90
+ if reduction not in ("mean", "median"):
91
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
92
+ if not scale_invariant and shift_invariant:
93
+ raise ValueError("Pure shift-invariant ensembling is not supported.")
94
+
95
+ def init_param(depth: torch.Tensor):
96
+ init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
97
+ init_max = depth.reshape(ensemble_size, -1).max(dim=1).values
98
+
99
+ if scale_invariant and shift_invariant:
100
+ init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
101
+ init_t = -init_s * init_min
102
+ param = torch.cat((init_s, init_t)).cpu().numpy()
103
+ elif scale_invariant:
104
+ init_s = 1.0 / init_max.clamp(min=1e-6)
105
+ param = init_s.cpu().numpy()
106
+ else:
107
+ raise ValueError("Unrecognized alignment.")
108
+
109
+ return param
110
+
111
+ def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
112
+ if scale_invariant and shift_invariant:
113
+ s, t = np.split(param, 2)
114
+ s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
115
+ t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
116
+ out = depth * s + t
117
+ elif scale_invariant:
118
+ s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
119
+ out = depth * s
120
+ else:
121
+ raise ValueError("Unrecognized alignment.")
122
+ return out
123
+
124
+ def ensemble(
125
+ depth_aligned: torch.Tensor, return_uncertainty: bool = False
126
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
127
+ uncertainty = None
128
+ if reduction == "mean":
129
+ prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
130
+ if return_uncertainty:
131
+ uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
132
+ elif reduction == "median":
133
+ prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
134
+ if return_uncertainty:
135
+ uncertainty = torch.median(
136
+ torch.abs(depth_aligned - prediction), dim=0, keepdim=True
137
+ ).values
138
+ else:
139
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
140
+ return prediction, uncertainty
141
+
142
+ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
143
+ cost = 0.0
144
+ depth_aligned = align(depth, param)
145
+
146
+ for i, j in torch.combinations(torch.arange(ensemble_size)):
147
+ diff = depth_aligned[i] - depth_aligned[j]
148
+ cost += (diff**2).mean().sqrt().item()
149
+
150
+ if regularizer_strength > 0:
151
+ prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
152
+ err_near = (0.0 - prediction.min()).abs().item()
153
+ err_far = (1.0 - prediction.max()).abs().item()
154
+ cost += (err_near + err_far) * regularizer_strength
155
+
156
+ return cost
157
+
158
+ def compute_param(depth: torch.Tensor):
159
+ import scipy
160
+
161
+ depth_to_align = depth.to(torch.float32)
162
+ if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
163
+ depth_to_align = resize_max_res(
164
+ depth_to_align, max_res, get_tv_resample_method("nearest-exact")
165
+ )
166
+
167
+ param = init_param(depth_to_align)
168
+
169
+ res = scipy.optimize.minimize(
170
+ partial(cost_fn, depth=depth_to_align),
171
+ param,
172
+ method="BFGS",
173
+ tol=tol,
174
+ options={"maxiter": max_iter, "disp": False},
175
+ )
176
+
177
+ return res.x
178
+
179
+ requires_aligning = scale_invariant or shift_invariant
180
+ ensemble_size = depth.shape[0]
181
+
182
+ if requires_aligning:
183
+ param = compute_param(depth)
184
+ depth = align(depth, param)
185
+
186
+ depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)
187
+
188
+ depth_max = depth.max()
189
+ if scale_invariant and shift_invariant:
190
+ depth_min = depth.min()
191
+ elif scale_invariant:
192
+ depth_min = 0
193
+ else:
194
+ raise ValueError("Unrecognized alignment.")
195
+ depth_range = (depth_max - depth_min).clamp(min=1e-6)
196
+ depth = (depth - depth_min) / depth_range
197
+ if output_uncertainty:
198
+ uncertainty /= depth_range
199
+
200
+ return depth, uncertainty # [1,1,H,W], [1,1,H,W]
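A minimal usage sketch of `ensemble_depth` (hedged: the import path is inferred from the file path above, the random input merely stands in for real ensemble predictions, and the alignment step needs `scipy` installed):

```py
import torch

from marigold.util.ensemble import ensemble_depth  # assumed import path

# Five affine-invariant depth predictions of the same 480x640 image, shape [B, 1, H, W].
preds = torch.rand(5, 1, 480, 640)

depth, uncertainty = ensemble_depth(
    preds,
    scale_invariant=True,
    shift_invariant=True,
    output_uncertainty=True,
    reduction="median",
)
print(depth.shape, uncertainty.shape)  # torch.Size([1, 1, 480, 640]) for both
```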
marigold/util/image_util.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+
22
+ import matplotlib
23
+ import numpy as np
24
+ import torch
25
+ from torchvision.transforms import InterpolationMode
26
+ from torchvision.transforms.functional import resize
27
+
28
+
29
+ def colorize_depth_maps(
30
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
31
+ ):
32
+ """
33
+ Colorize depth maps.
34
+ """
35
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
36
+
37
+ if isinstance(depth_map, torch.Tensor):
38
+ depth = depth_map.detach().cpu().squeeze().numpy()  # move to CPU before converting to numpy
39
+ elif isinstance(depth_map, np.ndarray):
40
+ depth = depth_map.copy().squeeze()
41
+ # reshape to [ (B,) H, W ]
42
+ if depth.ndim < 3:
43
+ depth = depth[np.newaxis, :, :]
44
+
45
+ # colorize
46
+ cm = matplotlib.colormaps[cmap]
47
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
48
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
49
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
50
+
51
+ if valid_mask is not None:
52
+ if isinstance(depth_map, torch.Tensor):
53
+ valid_mask = valid_mask.detach().cpu().numpy()
54
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
55
+ if valid_mask.ndim < 3:
56
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
57
+ else:
58
+ valid_mask = valid_mask[:, np.newaxis, :, :]
59
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
60
+ img_colored_np[~valid_mask] = 0
61
+
62
+ if isinstance(depth_map, torch.Tensor):
63
+ img_colored = torch.from_numpy(img_colored_np).float()
64
+ elif isinstance(depth_map, np.ndarray):
65
+ img_colored = img_colored_np
66
+
67
+ return img_colored
68
+
69
+
70
+ def chw2hwc(chw):
71
+ assert 3 == len(chw.shape)
72
+ if isinstance(chw, torch.Tensor):
73
+ hwc = torch.permute(chw, (1, 2, 0))
74
+ elif isinstance(chw, np.ndarray):
75
+ hwc = np.moveaxis(chw, 0, -1)
76
+ return hwc
77
+
78
+ def resize_max_res(
79
+ img: torch.Tensor,
80
+ max_edge_resolution: int,
81
+ resample_method: InterpolationMode = InterpolationMode.BILINEAR,
82
+ ) -> torch.Tensor:
83
+ """
84
+ Resize image to limit maximum edge length while keeping aspect ratio.
85
+
86
+ Args:
87
+ img (`torch.Tensor`):
88
+ Image tensor to be resized. Expected shape: [B, C, H, W]
89
+ max_edge_resolution (`int`):
90
+ Maximum edge length (pixel).
91
+ resample_method (`PIL.Image.Resampling`):
92
+ Resampling method used to resize images.
93
+
94
+ Returns:
95
+ `torch.Tensor`: Resized image.
96
+ """
97
+ assert 4 == img.dim(), f"Invalid input shape {img.shape}"
98
+
99
+ original_height, original_width = img.shape[-2:]
100
+ downscale_factor = min(
101
+ max_edge_resolution / original_width, max_edge_resolution / original_height
102
+ )
103
+
104
+ new_width = int(original_width * downscale_factor)
105
+ new_height = int(original_height * downscale_factor)
106
+
107
+ resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
108
+ return resized_img
109
+
110
+
111
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
112
+ resample_method_dict = {
113
+ "bilinear": InterpolationMode.BILINEAR,
114
+ "bicubic": InterpolationMode.BICUBIC,
115
+ "nearest": InterpolationMode.NEAREST_EXACT,
116
+ "nearest-exact": InterpolationMode.NEAREST_EXACT,
117
+ }
118
+ resample_method = resample_method_dict.get(method_str, None)
119
+ if resample_method is None:
120
+ raise ValueError(f"Unknown resampling method: {resample_method}")
121
+ else:
122
+ return resample_method
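A short sketch of the image helpers above (hedged: the import path is inferred from the file path, and the shapes are illustrative):

```py
import torch

from marigold.util.image_util import (  # assumed import path
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# Limit the longer edge of a [B, C, H, W] image to 768 px while keeping the aspect ratio.
img = torch.rand(1, 3, 720, 1280)
resized = resize_max_res(img, 768, get_tv_resample_method("bilinear"))
print(resized.shape)  # torch.Size([1, 3, 432, 768])

# Colorize a depth map with the default "Spectral" colormap and convert it to HWC layout.
depth = torch.rand(480, 640)
colored = colorize_depth_maps(depth, min_depth=0.0, max_depth=1.0)  # [1, 3, H, W] in [0, 1]
print(chw2hwc(colored[0]).shape)  # torch.Size([480, 640, 3])
```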