GonzaloMG committed on
Commit
ce2c7c0
1 Parent(s): 30e96fe

Upload pipeline.py

Files changed (1)
  1. pipeline.py +466 -0
pipeline.py ADDED
@@ -0,0 +1,466 @@
1
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the
17
+ # Marigold project website: https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+ # @GonzaloMartinGarcia
21
+ # Inference Pipeline for End-to-End Marigold and Stable Diffusion Depth Estimators
22
+ # ----------------------------------------------------------------------------------
23
+ # A streamlined version of the official MarigoldDepthPipeline from diffusers:
24
+ # https://github.com/huggingface/diffusers/blob/a98a839de75f1ad82d8d200c3bc2e4ff89929081/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py#L96
25
+ #
26
+ # This implementation is meant for use with the diffusers custom_pipeline feature.
27
+ # Modifications from the original code are marked with '# add' comments.
28
+
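+ # Example usage (a minimal sketch, not executed here). This file is intended to be
+ # loaded through the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`;
+ # the checkpoint id "GonzaloMG/marigold-e2e-ft-depth" and the image path below are
+ # illustrative assumptions, substitute the actual repository and input:
+ #
+ #   import torch
+ #   from diffusers import DiffusionPipeline
+ #   from diffusers.utils import load_image
+ #
+ #   pipe = DiffusionPipeline.from_pretrained(
+ #       "GonzaloMG/marigold-e2e-ft-depth",
+ #       custom_pipeline="GonzaloMG/marigold-e2e-ft-depth",
+ #       torch_dtype=torch.float16,
+ #   ).to("cuda")
+ #   image = load_image("path/to/image.png")
+ #   out = pipe(image)            # single deterministic step, no ensembling
+ #   depth = out.prediction       # np.ndarray of shape [N,H,W,1], values in [0, 1]
+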
29
+ from dataclasses import dataclass
30
+ from typing import List, Optional, Tuple, Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ from PIL import Image
35
+ from tqdm.auto import tqdm
36
+ from transformers import CLIPTextModel, CLIPTokenizer
37
+
38
+ from diffusers.image_processor import PipelineImageInput
39
+ from diffusers.models import (
40
+ AutoencoderKL,
41
+ UNet2DConditionModel,
42
+ )
43
+ from diffusers.schedulers import (
44
+ DDIMScheduler,
45
+ )
46
+ from diffusers.utils import (
47
+ BaseOutput,
48
+ logging,
49
+ )
50
+ from diffusers import DiffusionPipeline
51
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
52
+
53
+ # add
54
+ def zeros_tensor(
55
+ shape: Union[Tuple, List],
56
+ device: Optional["torch.device"] = None,
57
+ dtype: Optional["torch.dtype"] = None,
58
+ layout: Optional["torch.layout"] = None,
59
+ ):
60
+ """
61
+ A helper function to create tensors of zeros on the desired `device`.
62
+ Mirrors randn_tensor from diffusers.utils.torch_utils.
63
+ """
64
+ layout = layout or torch.strided
65
+ device = device or torch.device("cpu")
66
+ latents = torch.zeros(list(shape), dtype=dtype, layout=layout).to(device)
67
+ return latents
68
+
69
+
70
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
71
+
72
+ @dataclass
73
+ class E2EMarigoldDepthOutput(BaseOutput):
74
+ """
75
+ Output class for Marigold monocular depth prediction pipeline.
76
+
77
+ Args:
78
+ prediction (`np.ndarray`, `torch.Tensor`):
79
+ Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height
80
+ \times width$, regardless of whether the images were passed as a 4D array or a list.
81
+ latent (`None`, `torch.Tensor`):
82
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
83
+ The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
84
+ """
85
+
86
+ prediction: Union[np.ndarray, torch.Tensor]
87
+ latent: Union[None, torch.Tensor]
88
+
89
+
90
+ class E2EMarigoldDepthPipeline(DiffusionPipeline):
91
+ """
92
+ # add
93
+ Pipeline for monocular depth estimation using the E2E FT Marigold and SD method: https://gonzalomartingarcia.github.io/diffusion-e2e-ft/
94
+ Implementation is built upon Marigold: https://marigoldmonodepth.github.io
95
+
96
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
97
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
98
+
99
+ Args:
100
+ unet (`UNet2DConditionModel`):
101
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
102
+ vae (`AutoencoderKL`):
103
+ Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
104
+ representations.
105
+ scheduler (`DDIMScheduler`):
106
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
107
+ text_encoder (`CLIPTextModel`):
108
+ Text-encoder, for empty text embedding.
109
+ tokenizer (`CLIPTokenizer`):
110
+ CLIP tokenizer.
111
+ default_processing_resolution (`int`, *optional*):
112
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
113
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
114
+ default value is used. This is required to ensure reasonable results with various model flavors trained
115
+ with varying optimal processing resolution values.
116
+ """
117
+
118
+ model_cpu_offload_seq = "text_encoder->unet->vae"
119
+
120
+ def __init__(
121
+ self,
122
+ unet: UNet2DConditionModel,
123
+ vae: AutoencoderKL,
124
+ scheduler: DDIMScheduler,
125
+ text_encoder: CLIPTextModel,
126
+ tokenizer: CLIPTokenizer,
127
+ default_processing_resolution: Optional[int] = 768, # add
128
+ ):
129
+ super().__init__()
130
+
131
+ self.register_modules(
132
+ unet=unet,
133
+ vae=vae,
134
+ scheduler=scheduler,
135
+ text_encoder=text_encoder,
136
+ tokenizer=tokenizer,
137
+ )
138
+ self.register_to_config(
139
+ default_processing_resolution=default_processing_resolution,
140
+ )
141
+
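+ # For the standard Stable Diffusion VAE (four entries in block_out_channels) this evaluates to 8.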
142
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
143
+ self.default_processing_resolution = default_processing_resolution
144
+ self.empty_text_embedding = None
145
+
146
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
147
+
148
+ def check_inputs(
149
+ self,
150
+ image: PipelineImageInput,
151
+ processing_resolution: int,
152
+ resample_method_input: str,
153
+ resample_method_output: str,
154
+ batch_size: int,
155
+ output_type: str,
156
+ ) -> int:
157
+ if processing_resolution is None:
158
+ raise ValueError(
159
+ "`processing_resolution` is not specified and could not be resolved from the model config."
160
+ )
161
+ if processing_resolution < 0:
162
+ raise ValueError(
163
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
164
+ "downsampled processing."
165
+ )
166
+ if processing_resolution % self.vae_scale_factor != 0:
167
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
168
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
169
+ raise ValueError(
170
+ "`resample_method_input` takes string values compatible with PIL library: "
171
+ "nearest, nearest-exact, bilinear, bicubic, area."
172
+ )
173
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
174
+ raise ValueError(
175
+ "`resample_method_output` takes string values compatible with PIL library: "
176
+ "nearest, nearest-exact, bilinear, bicubic, area."
177
+ )
178
+ if batch_size < 1:
179
+ raise ValueError("`batch_size` must be positive.")
180
+ if output_type not in ["pt", "np"]:
181
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
182
+
183
+ # image checks
184
+ num_images = 0
185
+ W, H = None, None
186
+ if not isinstance(image, list):
187
+ image = [image]
188
+ for i, img in enumerate(image):
189
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
190
+ if img.ndim not in (2, 3, 4):
191
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
192
+ H_i, W_i = img.shape[-2:]
193
+ N_i = 1
194
+ if img.ndim == 4:
195
+ N_i = img.shape[0]
196
+ elif isinstance(img, Image.Image):
197
+ W_i, H_i = img.size
198
+ N_i = 1
199
+ else:
200
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
201
+ if W is None:
202
+ W, H = W_i, H_i
203
+ elif (W, H) != (W_i, H_i):
204
+ raise ValueError(
205
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
206
+ )
207
+ num_images += N_i
208
+
209
+ return num_images
210
+
211
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
212
+ if not hasattr(self, "_progress_bar_config"):
213
+ self._progress_bar_config = {}
214
+ elif not isinstance(self._progress_bar_config, dict):
215
+ raise ValueError(
216
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
217
+ )
218
+
219
+ progress_bar_config = dict(**self._progress_bar_config)
220
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
221
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
222
+ if iterable is not None:
223
+ return tqdm(iterable, **progress_bar_config)
224
+ elif total is not None:
225
+ return tqdm(total=total, **progress_bar_config)
226
+ else:
227
+ raise ValueError("Either `total` or `iterable` has to be defined.")
228
+
229
+ @torch.no_grad()
230
+ def __call__(
231
+ self,
232
+ image: PipelineImageInput,
233
+ processing_resolution: Optional[int] = None,
234
+ match_input_resolution: bool = True,
235
+ resample_method_input: str = "bilinear",
236
+ resample_method_output: str = "bilinear",
237
+ batch_size: int = 1,
238
+ output_type: str = "np",
239
+ output_latent: bool = False,
240
+ return_dict: bool = True,
241
+ ):
242
+ """
243
+ Function invoked when calling the pipeline.
244
+
245
+ Args:
246
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
247
+ `List[torch.Tensor]`): An input image or images used as input for the depth estimation task. For
248
+ arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
249
+ by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
250
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
251
+ same width and height.
252
+ processing_resolution (`int`, *optional*, defaults to `None`):
253
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
254
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
255
+ value `None` resolves to the optimal value from the model config.
256
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
257
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
258
+ side of the output will equal `processing_resolution`.
259
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
260
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
261
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
262
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
263
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
264
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
265
+ batch_size (`int`, *optional*, defaults to `1`):
266
+ Batch size; only matters when passing a tensor of images.
267
+ output_type (`str`, *optional*, defaults to `"np"`):
268
+ Preferred format of the output's `prediction`. The accepted values are: `"np"` (numpy array) or `"pt"` (torch tensor).
269
+ output_latent (`bool`, *optional*, defaults to `False`):
270
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
271
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
272
+ `latents` argument.
273
+ return_dict (`bool`, *optional*, defaults to `True`):
274
+ Whether or not to return a [`~pipelines.marigold.E2EMarigoldDepthOutput`] instead of a plain tuple.
275
+
276
+ # add
277
+ E2E FT models are deterministic single-step models and use no ensembling, i.e., E=1.
278
+ """
279
+
280
+ # 0. Resolving variables.
281
+ device = self._execution_device
282
+ dtype = self.dtype
283
+
284
+ # Model-specific optimal default values leading to fast and reasonable results.
285
+ if processing_resolution is None:
286
+ processing_resolution = self.default_processing_resolution
287
+
288
+ # 1. Check inputs.
289
+ num_images = self.check_inputs(
290
+ image,
291
+ processing_resolution,
292
+ resample_method_input,
293
+ resample_method_output,
294
+ batch_size,
295
+ output_type,
296
+ )
297
+
298
+ # 2. Prepare empty text conditioning.
299
+ # Model invocation: self.tokenizer, self.text_encoder.
300
+ if self.empty_text_embedding is None:
301
+ prompt = ""
302
+ text_inputs = self.tokenizer(
303
+ prompt,
304
+ padding="do_not_pad",
305
+ max_length=self.tokenizer.model_max_length,
306
+ truncation=True,
307
+ return_tensors="pt",
308
+ )
309
+ text_input_ids = text_inputs.input_ids.to(device)
310
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
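+ # The empty-prompt embedding is computed once and cached; it is reused below as the text conditioning for every batch.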
311
+
312
+ # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
313
+ # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
314
+ # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
315
+ # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
316
+ # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
317
+ # operation and leads to the most reasonable results. Using the native image resolution or any other processing
318
+ # resolution can lead to loss of either fine details or global context in the output predictions.
319
+ image, padding, original_resolution = self.image_processor.preprocess(
320
+ image, processing_resolution, resample_method_input, device, dtype
321
+ ) # [N,3,PPH,PPW]
322
+
323
+ # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
324
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
325
+ # Latents of each such prediction across all input images and all ensemble members are represented in the
326
+ # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
327
+ # into latent space and replicated `E` times. Encoding into latent space happens in batches of size `batch_size`.
328
+ # Model invocation: self.vae.encoder.
329
+ image_latent, pred_latent = self.prepare_latents(
330
+ image, batch_size
331
+ ) # [N*E,4,h,w], [N*E,4,h,w]
332
+
333
+ del image
334
+
335
+ batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
336
+ batch_size, 1, 1
337
+ ) # [B,2,1024]
338
+
339
+ # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
340
+ # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
341
+ # outputs noise for the predicted modality's latent space.
342
+ # Model invocation: self.unet.
343
+ pred_latents = []
344
+
345
+ for i in self.progress_bar(
346
+ range(0, num_images, batch_size), leave=True, desc="E2E FT predictions..."
347
+ ):
348
+ batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w]
349
+ batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w]
350
+ effective_batch_size = batch_image_latent.shape[0]
351
+ text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024]
352
+
353
+ # add
354
+ # Single step inference for E2E FT models
355
+ self.scheduler.set_timesteps(1, device=device)
356
+ for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."):
357
+ batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w]
358
+ noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w]
359
+ batch_pred_latent = self.scheduler.step(
360
+ noise, t, batch_pred_latent
361
+ ).pred_original_sample # [B,4,h,w] # add
362
+ # directly take pred_original_sample rather than prev_sample
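+ # With a single timestep and a zeros-initialized prediction latent, this scheduler step is deterministic:
+ # the UNet output is converted into the estimate of the clean depth latent (pred_original_sample), which is
+ # used directly as the final prediction latent instead of continuing the diffusion chain.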
363
+
364
+ pred_latents.append(batch_pred_latent)
365
+
366
+ pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
367
+
368
+ del (
369
+ pred_latents,
370
+ image_latent,
371
+ batch_empty_text_embedding,
372
+ batch_image_latent,
373
+ batch_pred_latent,
374
+ text,
375
+ batch_latent,
376
+ noise,
377
+ )
378
+
379
+ # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
380
+ # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
381
+ # Model invocation: self.vae.decoder.
382
+ prediction = torch.cat(
383
+ [
384
+ self.decode_prediction(pred_latent[i : i + batch_size])
385
+ for i in range(0, pred_latent.shape[0], batch_size)
386
+ ],
387
+ dim=0,
388
+ ) # [N*E,1,PPH,PPW]
389
+
390
+ if not output_latent:
391
+ pred_latent = None
392
+
393
+ # 7. Remove padding. The output shape is (PH, PW).
394
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW]
395
+
396
+ # 9. If `match_input_resolution` is set, the output prediction is upsampled to match the
397
+ # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
398
+ # Depending on the downstream use-case, upsampling can also be chosen based on the tolerated artifacts by
399
+ # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
400
+ if match_input_resolution:
401
+ prediction = self.image_processor.resize_antialias(
402
+ prediction, original_resolution, resample_method_output, is_aa=False
403
+ ) # [N,1,H,W]
404
+
405
+ # 10. Prepare the final outputs.
406
+ if output_type == "np":
407
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1]
408
+
409
+ # 11. Offload all models
410
+ self.maybe_free_model_hooks()
411
+
412
+ if not return_dict:
413
+ return (prediction, pred_latent)
414
+
415
+ return E2EMarigoldDepthOutput(
416
+ prediction=prediction,
417
+ latent=pred_latent,
418
+ )
419
+
420
+ def prepare_latents(
421
+ self,
422
+ image: torch.Tensor,
423
+ batch_size: int,
424
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
425
+ def retrieve_latents(encoder_output):
426
+ if hasattr(encoder_output, "latent_dist"):
427
+ return encoder_output.latent_dist.mode()
428
+ elif hasattr(encoder_output, "latents"):
429
+ return encoder_output.latents
430
+ else:
431
+ raise AttributeError("Could not access latents of provided encoder_output")
432
+
433
+ image_latent = torch.cat(
434
+ [
435
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
436
+ for i in range(0, image.shape[0], batch_size)
437
+ ],
438
+ dim=0,
439
+ ) # [N,4,h,w]
440
+ image_latent = image_latent * self.vae.config.scaling_factor # [N*E,4,h,w]
441
+
442
+ # add
443
+ # provide zeros as noised latent
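+ # (standard Marigold samples Gaussian noise here via randn_tensor; the E2E FT model instead starts from an
+ # all-zeros latent, which keeps the single-step prediction deterministic)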
444
+ pred_latent = zeros_tensor(
445
+ image_latent.shape,
446
+ device=image_latent.device,
447
+ dtype=image_latent.dtype,
448
+ ) # [N*E,4,h,w]
449
+
450
+ return image_latent, pred_latent
451
+
452
+ def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
453
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
454
+ raise ValueError(
455
+ f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
456
+ )
457
+
458
+ prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
459
+
460
+ prediction = prediction.mean(dim=1, keepdim=True) # [B,1,H,W]
461
+ prediction = torch.clip(prediction, -1.0, 1.0) # [B,1,H,W]
462
+
463
+ # add
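+ # min-max normalize the channel-averaged, clipped decoder output to [0, 1];
+ # note that min/max are taken over the whole batch tensor, not per image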
464
+ prediction = (prediction - prediction.min()) / (prediction.max() - prediction.min())
465
+
466
+ return prediction # [B,1,H,W]