# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import (
    AutoencoderKL,
    ControlNetModel,
    ImageProjection,
    UNet2DConditionModel,
)
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
    StableDiffusionXLPipelineOutput,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import (
    is_compiled_module,
    is_torch_version,
    randn_tensor,
)
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
):
    """Configure the scheduler and return its timestep schedule.

    Calls ``scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)``
    and then reads ``scheduler.timesteps``.

    Args:
        scheduler: Scheduler exposing ``set_timesteps`` and ``timesteps``.
        num_inference_steps: Number of steps forwarded to ``set_timesteps``.
        device: Device forwarded to ``set_timesteps``.
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        A ``(timesteps, num_inference_steps)`` tuple; ``num_inference_steps`` is
        returned exactly as it was passed in.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class StableDiffusionXLRecolorPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
    IPAdapterMixin,
    FromSingleFileMixin,
):
    """SDXL pipeline that denoises random latents under ControlNet conditioning.

    The caller supplies pre-computed prompt embeddings (see ``encode_prompt``) and
    one control image per ControlNet. Classifier-free guidance is only applied
    during the first three denoising steps; afterwards the unconditional half of
    every batched input is dropped.
    """

    # leave controlnet out on purpose because it iterates with unet
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "feature_extractor",
        "image_encoder",
    ]
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[
            ControlNetModel,
            List[ControlNetModel],
            Tuple[ControlNetModel],
            MultiControlNetModel,
        ],
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
        feature_extractor: Optional[CLIPImageProcessor] = None,
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    ):
        """Register the sub-models and build the image processors.

        Args:
            vae: Autoencoder used to decode the final latents.
            text_encoder: First text encoder (may be ``None``, see
                ``_optional_components``).
            text_encoder_2: Second text encoder.
            tokenizer: Tokenizer paired with ``text_encoder``.
            tokenizer_2: Tokenizer paired with ``text_encoder_2``.
            unet: Denoising UNet.
            controlnet: A single ControlNet, a ``MultiControlNetModel``, or a
                list/tuple of ControlNets (wrapped into a ``MultiControlNetModel``).
            scheduler: Noise scheduler.
            force_zeros_for_empty_prompt: Stored in the pipeline config.
            add_watermarker: Accepted but not used by this pipeline.
            feature_extractor: Image processor used by ``encode_image``.
            image_encoder: Vision encoder used by ``encode_image``.
        """
        super().__init__()

        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        # Control images are fed to the ControlNet without normalization.
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )
        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
        )

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        device: Optional[torch.device] = None,
        do_classifier_free_guidance: bool = True,
    ):
        """Encode a prompt (and optionally a negative prompt) with the text encoders.

        The same prompt text is fed to every available tokenizer/text-encoder pair.
        The penultimate hidden states of all encoders are concatenated along the
        feature axis; the pooled embedding is taken from the last encoder.

        Args:
            prompt: Prompt string or list of prompt strings.
            negative_prompt: Negative prompt; ``None`` or ``""`` encodes the empty
                string. A single string is repeated for every prompt.
            device: Target device; defaults to ``self._execution_device``.
            do_classifier_free_guidance: Whether to also encode the negative prompt.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
            negative_pooled_prompt_embeds)``. Both negative entries are ``None``
            when ``do_classifier_free_guidance`` is false.

        Raises:
            ValueError: If ``prompt`` is ``None`` or a list-valued
                ``negative_prompt`` does not match the prompt batch size.
        """
        device = device or self._execution_device

        if prompt is None:
            # Previously this fell through with `batch_size` unbound.
            raise ValueError("`prompt` must be a string or a list of strings.")
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Define tokenizers and text encoders
        tokenizers = (
            [self.tokenizer, self.tokenizer_2]
            if self.tokenizer is not None
            else [self.tokenizer_2]
        )
        text_encoders = (
            [self.text_encoder, self.text_encoder_2]
            if self.text_encoder is not None
            else [self.text_encoder_2]
        )

        # Both encoders receive the same text; `zip` stops at the shortest
        # sequence, so a single encoder simply consumes the first entry.
        prompts = [prompt, prompt]
        prompt_embeds_list = []
        for prompt_text, tokenizer, text_encoder in zip(
            prompts, tokenizers, text_encoders
        ):
            text_inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            encoder_output = text_encoder(
                text_inputs.input_ids.to(device), output_hidden_states=True
            )

            # Only the pooled output of the final text encoder is kept.
            pooled_prompt_embeds = encoder_output[0]
            prompt_embeds_list.append(encoder_output.hidden_states[-2])

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        negative_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        if do_classifier_free_guidance:
            negative_prompt = negative_prompt or ""

            # One negative prompt per prompt so the batch sizes line up.
            if isinstance(negative_prompt, str):
                negative_prompt = batch_size * [negative_prompt]
            elif len(negative_prompt) != batch_size:
                raise ValueError(
                    f"`negative_prompt` has batch size {len(negative_prompt)}, but"
                    f" `prompt` has batch size {batch_size}. Please make sure that"
                    " the passed `negative_prompt` matches the batch size of"
                    " `prompt`."
                )

            uncond_tokens = [negative_prompt, negative_prompt]
            # Pad the negative prompt to the sequence length of the positive one.
            max_length = prompt_embeds.shape[1]

            negative_prompt_embeds_list = []
            for negative_text, tokenizer, text_encoder in zip(
                uncond_tokens, tokenizers, text_encoders
            ):
                uncond_input = tokenizer(
                    negative_text,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                encoder_output = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )

                # Only the pooled output of the final text encoder is kept.
                negative_pooled_prompt_embeds = encoder_output[0]
                negative_prompt_embeds_list.append(encoder_output.hidden_states[-2])

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)

        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder_2.dtype, device=device
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size, seq_len, -1
            )

        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
                bs_embed, -1
            )

        return (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(
        self, image, device, num_images_per_prompt, output_hidden_states=None
    ):
        """Encode an image with ``self.image_encoder``.

        Non-tensor inputs are first converted with ``self.feature_extractor``.

        Args:
            image: Tensor, or any input accepted by ``self.feature_extractor``.
            device: Device the pixel values are moved to.
            num_images_per_prompt: Repeat count applied along the batch axis.
            output_hidden_states: When truthy, return penultimate hidden states
                instead of the pooled image embeddings.

        Returns:
            ``(conditional, unconditional)``. With hidden states, the
            unconditional branch encodes an all-zero image; otherwise it is a
            zero tensor shaped like the conditional embeddings.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(
                image, output_hidden_states=True
            ).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = (
                uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0
                )
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds

    def prepare_ip_adapter_image_embeds(
        self,
        ip_adapter_image,
        device,
        do_classifier_free_guidance,
    ):
        """Encode one IP-Adapter image per loaded IP-Adapter projection layer.

        Args:
            ip_adapter_image: A single image or a list with one entry per
                projection layer in ``self.unet.encoder_hid_proj``.
            device: Device for the resulting embeddings.
            do_classifier_free_guidance: When true, each result stacks the
                negative embedding before the positive one along dim 0.

        Returns:
            A list with one embedding tensor per IP-Adapter.

        Raises:
            ValueError: If the number of images differs from the number of
                IP-Adapter projection layers.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []

        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        image_projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(ip_adapter_image) != len(image_projection_layers):
            raise ValueError(
                "`ip_adapter_image` must have same length as the number of IP"
                f" Adapters. Got {len(ip_adapter_image)} images and"
                f" {len(image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, image_projection_layers
        ):
            # Plain `ImageProjection` layers take pooled embeddings; every other
            # projection type is fed hidden states.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            if do_classifier_free_guidance:
                single_image_embeds = torch.cat(
                    [negative_image_embeds[i], single_image_embeds], dim=0
                )

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
        """Preprocess a control image into a ControlNet conditioning tensor.

        Args:
            image: Input accepted by ``self.control_image_processor.preprocess``.
            device: Target device.
            dtype: Target dtype.
            do_classifier_free_guidance: When true, the batch is duplicated so it
                lines up with the doubled latent batch.

        Returns:
            The preprocessed image tensor on ``device`` with ``dtype``.
        """
        image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)
        image_batch_size = image.shape[0]

        # NOTE(review): repeating by the image's own batch size is a no-op for a
        # single image but squares the batch for larger inputs — confirm intent.
        image = image.repeat_interleave(image_batch_size, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    def prepare_latents(
        self, batch_size, num_channels_latents, height, width, dtype, device
    ):
        """Sample initial noise latents scaled for the scheduler.

        Args:
            batch_size: Number of latent samples.
            num_channels_latents: Latent channel count.
            height: Pixel height; divided by ``self.vae_scale_factor``.
            width: Pixel width; divided by ``self.vae_scale_factor``.
            dtype: Latent dtype.
            device: Latent device.

        Returns:
            Random latents multiplied by ``self.scheduler.init_noise_sigma``.
        """
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        latents = randn_tensor(shape, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        """Guidance scale currently in effect (set by ``__call__``)."""
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        """Whether classifier-free guidance is currently active."""
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def denoising_end(self):
        """Value of ``_denoising_end``, or ``None`` when it was never set.

        Nothing in this file assigns ``_denoising_end``, so a plain attribute
        read would always raise ``AttributeError``.
        """
        return getattr(self, "_denoising_end", None)

    @property
    def num_timesteps(self):
        """Number of scheduler timesteps used by the last ``__call__``."""
        return self._num_timesteps

    @torch.no_grad()
    def __call__(
        self,
        image: PipelineImageInput = None,
        num_inference_steps: int = 8,
        guidance_scale: float = 2.0,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        **kwargs,
    ):
        """Run ControlNet-conditioned denoising and decode the result.

        Classifier-free guidance is applied for the first three steps only; after
        that the unconditional half of every batched input is dropped and
        ``self._guidance_scale`` is set to ``0.0``.

        Args:
            image: Control image for a single ControlNet, or a list with one
                control image per ControlNet for a ``MultiControlNetModel``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight.
            prompt_embeds: Prompt embeddings (required), e.g. from
                ``encode_prompt``.
            negative_prompt_embeds: Negative prompt embeddings; required while
                classifier-free guidance is active.
            pooled_prompt_embeds: Pooled prompt embeddings (required).
            negative_pooled_prompt_embeds: Pooled negative embeddings; required
                while classifier-free guidance is active.
            ip_adapter_image: Optional IP-Adapter image(s).
            controlnet_conditioning_scale: Scale(s) applied to the ControlNet
                residuals.
            control_guidance_start: Fraction(s) of the schedule at which each
                ControlNet starts to apply.
            control_guidance_end: Fraction(s) of the schedule at which each
                ControlNet stops applying.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``StableDiffusionXLPipelineOutput`` whose ``images`` field holds the
            first element of the postprocessed decoder output.

        Raises:
            ValueError: If required embeddings are missing.
            TypeError: If ``self.controlnet`` is neither a ``ControlNetModel`` nor
                a ``MultiControlNetModel``.
        """
        controlnet = self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        self._guidance_scale = guidance_scale

        # 1. Check inputs. Fail early with a clear message instead of an
        # AttributeError/TypeError deep inside the pipeline.
        if prompt_embeds is None or pooled_prompt_embeds is None:
            raise ValueError(
                "`prompt_embeds` and `pooled_prompt_embeds` must be provided, e.g."
                " as returned by `encode_prompt`."
            )
        if self.do_classifier_free_guidance and (
            negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
        ):
            raise ValueError(
                "`negative_prompt_embeds` and `negative_pooled_prompt_embeds` must"
                " be provided when classifier-free guidance is enabled."
            )

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device

        if isinstance(controlnet, MultiControlNetModel) and isinstance(
            controlnet_conditioning_scale, float
        ):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
                controlnet.nets
            )

        # 3.2 Encode ip_adapter_image
        if ip_adapter_image is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                device,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(
                image=image,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                )
                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            # An `assert` would be stripped under `python -O`.
            raise TypeError(
                "`controlnet` must be a `ControlNetModel` or a"
                f" `MultiControlNetModel`, got {type(controlnet)}."
            )

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
        )

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )

        # 7.2 Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        # (height, width), a (0, 0) pair, then (height, width) again, all taken
        # from the control image.
        add_time_ids = torch.tensor(
            [height, width, 0, 0, height, width]
        ).unsqueeze(0)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            # The negative branch uses the same time ids as the positive one.
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)
        added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        }

        # 8. Denoising loop
        # Index of the last step that runs with classifier-free guidance.
        cfg_cutoff_step = 2
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # controlnet(s) inference
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds
                controlnet_added_cond_kwargs = added_cond_kwargs

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [
                        c * s
                        for c, s in zip(
                            controlnet_conditioning_scale, controlnet_keep[i]
                        )
                    ]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                if ip_adapter_image is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=None,
                    cross_attention_kwargs={},
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                if i == cfg_cutoff_step:
                    # Switch classifier-free guidance off for the remaining
                    # steps: keep only the last (conditional) entry of every
                    # batched input.
                    prompt_embeds = prompt_embeds[-1:]
                    add_text_embeds = add_text_embeds[-1:]
                    add_time_ids = add_time_ids[-1:]
                    added_cond_kwargs = {
                        "text_embeds": add_text_embeds,
                        "time_ids": add_time_ids,
                    }
                    if isinstance(image, torch.Tensor):
                        # Single ControlNet: slice the batch axis directly.
                        # Iterating the tensor would slice channels instead.
                        image = image[-1:]
                    else:
                        image = [single_image[-1:] for single_image in image]
                    if ip_adapter_image is not None:
                        # Keep the IP-Adapter embeddings in step with the
                        # now-halved batch.
                        image_embeds = [embeds[-1:] for embeds in image_embeds]

                    self._guidance_scale = 0.0

                # update the progress bar once per completed scheduler step
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image)[0]

        # Offload all models
        self.maybe_free_model_hooks()

        return StableDiffusionXLPipelineOutput(images=image)