mikonvergence committed on
Commit 1f54d7b
1 Parent(s): aca81a2

Create src/pipeline_stable_diffusion_controlnet_inpaint.py

src/ControlNetInpaint/src DELETED
File without changes
src/ControlNetInpaint/src/pipeline_stable_diffusion_controlnet_inpaint.py ADDED
@@ -0,0 +1,507 @@
import torch
import PIL.Image
import numpy as np

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
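# Added commentary (not part of the original commit): the wildcard import above is what brings
# the typing helpers (Union, List, Optional, Callable, Dict, Any) as well as
# StableDiffusionControlNetPipeline, StableDiffusionPipelineOutput and replace_example_docstring
# into this module's namespace; the rest of the file relies on it.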

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> # !pip install opencv-python transformers accelerate
        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> import cv2
        >>> from PIL import Image
        >>> # download an image
        >>> image = load_image(
        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        ... )
        >>> image = np.array(image)
        >>> mask_image = load_image(
        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
        ... )
        >>> mask_image = np.array(mask_image)
        >>> # get canny image
        >>> canny_image = cv2.Canny(image, 100, 200)
        >>> canny_image = canny_image[:, :, None]
        >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
        >>> canny_image = Image.fromarray(canny_image)

        >>> # load control net and stable diffusion v1-5
        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
        ... )

        >>> # speed up diffusion process with faster scheduler and memory optimization
        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        >>> # remove following line if xformers is not installed
        >>> pipe.enable_xformers_memory_efficient_attention()

        >>> pipe.enable_model_cpu_offload()

        >>> # generate image
        >>> generator = torch.manual_seed(0)
        >>> image = pipe(
        ...     "futuristic-looking doggo",
        ...     num_inference_steps=20,
        ...     generator=generator,
        ...     image=image,
        ...     control_image=canny_image,
        ...     mask_image=mask_image
        ... ).images[0]
        ```
"""


def prepare_mask_and_masked_image(image, mask):
    """
    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
    ``image`` and ``1`` for the ``mask``.
    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
    Raises:
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
            should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
            (or the other way around).
    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
            dimensions: ``batch x channels x height x width``.
    """
    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)

    return mask, masked_image
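# Added commentary (not part of the original commit): a minimal, hypothetical shape check for the
# helper above. A 512x512 PIL image/mask pair should come back as a (1, 1, 512, 512) mask and a
# (1, 3, 512, 512) masked image, e.g.:
#
#     img = PIL.Image.new("RGB", (512, 512))
#     msk = PIL.Image.new("L", (512, 512), color=255)
#     m, masked = prepare_mask_and_masked_image(img, msk)
#     assert m.shape == (1, 1, 512, 512) and masked.shape == (1, 3, 512, 512)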

class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.

    This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        controlnet ([`ControlNetModel`]):
            Provides additional conditioning to the unet during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        mask = mask.to(device=device, dtype=dtype)

        masked_image = masked_image.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        if isinstance(generator, list):
            masked_image_latents = [
                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            masked_image_latents = torch.cat(masked_image_latents, dim=0)
        else:
            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        return mask, masked_image_latents
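        # Added commentary (not part of the original commit): the tensors returned above are a
        # binary mask resized to the latent resolution (height // vae_scale_factor) and the VAE
        # encoding of the masked image, scaled by vae.config.scaling_factor so it matches the
        # scale of the noise latents it is concatenated with in __call__. Both are duplicated
        # when classifier-free guidance is active so their batch dimension matches the doubled
        # latent input.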

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet.
                If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                also be accepted as an image. The control image is automatically resized to fit the output image.
            mask_image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet.
        Examples:
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, control_image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare image
        control_image = self.prepare_image(
            control_image,
            width,
            height,
            batch_size * num_images_per_prompt,
            num_images_per_prompt,
            device,
            self.controlnet.dtype,
        )

        if do_classifier_free_guidance:
            control_image = torch.cat([control_image] * 2)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.controlnet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
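        # Added commentary (not part of the original commit): num_channels_latents is read from
        # the ControlNet config (4 channels) rather than the UNet config, because an inpainting
        # UNet such as the one in runwayml/stable-diffusion-inpainting expects 9 input channels:
        # 4 noise latents + 1 resized mask + 4 masked-image latents, which are concatenated
        # right before the UNet call in the denoising loop below.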

        # EXTRA: prepare mask latents
        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=control_image,
                    return_dict=False,
                )

                down_block_res_samples = [
                    down_block_res_sample * controlnet_conditioning_scale
                    for down_block_res_sample in down_block_res_samples
                ]
                mid_block_res_sample *= controlnet_conditioning_scale

                # predict the noise residual
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)