|
from typing import Callable, List, Optional, Union |
|
|
|
import PIL |
|
import torch |
|
from transformers import ( |
|
CLIPImageProcessor, |
|
CLIPSegForImageSegmentation, |
|
CLIPSegProcessor, |
|
CLIPTextModel, |
|
CLIPTokenizer, |
|
) |
|
|
|
from diffusers import DiffusionPipeline |
|
from diffusers.configuration_utils import FrozenDict |
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline |
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler |
|
from diffusers.utils import deprecate, is_accelerate_available, logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class TextInpainting(DiffusionPipeline):
    r"""
    Pipeline for text based inpainting using Stable Diffusion.
    Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        segmentation_model ([`CLIPSegForImageSegmentation`]):
            CLIPSeg model used to generate the inpainting mask from the given text. Please refer to the
            [CLIPSeg documentation](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
        segmentation_processor ([`CLIPSegProcessor`]):
            CLIPSeg processor that turns the input image and the mask text into inputs for `segmentation_model`.
            Please refer to the [CLIPSeg documentation](https://huggingface.co/docs/transformers/model_doc/clipseg)
            for details.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`], *optional*):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
            Passing `None` disables the check; a warning is logged when that happens.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        segmentation_model: CLIPSegForImageSegmentation,
        segmentation_processor: CLIPSegProcessor,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        # Older scheduler configs used a different `steps_offset`; patch the (frozen) config in place so the
        # scheduler behaves as the current Stable Diffusion pipelines expect, and emit a deprecation notice.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Same treatment for schedulers (e.g. PNDM) whose config leaves `skip_prk_steps` disabled.
        if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration"
                " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
                " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
                " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
                " Hub, it would be very nice if you could open a Pull request for the"
                " `scheduler/scheduler_config.json` file"
            )
            deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["skip_prk_steps"] = True
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            segmentation_model=segmentation_model,
            segmentation_processor=segmentation_processor,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                the UNet config stores one `attention_head_dim` per block (a list), the smallest head size is used
                instead. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this
                case, `attention_head_dim` must be a multiple of `slice_size`. `None` disables slicing.
        """
        if slice_size == "auto":
            head_dim = self.unet.config.attention_head_dim
            if isinstance(head_dim, int):
                # Half the attention head size is usually a good trade-off between speed and memory.
                slice_size = head_dim // 2
            else:
                # Per-block head sizes: `list // 2` would raise, so fall back to the smallest head size.
                slice_size = min(head_dim)
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # Passing `None` as the slice size turns slicing off.
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self, gpu_id: int = 0):
        r"""
        Offloads the diffusion models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        The CLIPSeg `segmentation_model` is not offloaded and stays on its current device.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the offloaded submodules are executed on.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            # `safety_checker` may be None when the user disabled it.
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        # With sequential offload, the real execution device is recorded on the accelerate hooks of the submodules.
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        text: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`PIL.Image.Image`):
                `Image` which will be inpainted, *i.e.* the parts of the image selected by the mask generated from
                `text` will be repainted according to `prompt`. Only PIL images are currently supported: the generated
                mask is resized to `image.size`, which a tensor does not provide.
            text (`str`):
                The text to use to generate the mask.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            **kwargs:
                Accepted for call-site compatibility but not used; a warning is logged listing any that are passed.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.

        Raises:
            TypeError: If `image` is not a `PIL.Image.Image`.
        """
        # Fail fast: the mask is resized with `image.size`, which only PIL images provide as a (width, height)
        # tuple. Checking up front avoids running the segmentation model on an input that cannot be used.
        if not isinstance(image, PIL.Image.Image):
            raise TypeError(
                f"`image` must be a `PIL.Image.Image` so the generated mask can be resized to it, but got {type(image)}."
            )

        if kwargs:
            logger.warning("Ignoring unsupported keyword arguments: %s", ", ".join(sorted(kwargs)))

        # Use CLIPSeg to turn the (text, image) pair into a soft mask; sigmoid maps the logits into [0, 1].
        inputs = self.segmentation_processor(
            text=[text], images=[image], padding="max_length", return_tensors="pt"
        ).to(self.device)
        outputs = self.segmentation_model(**inputs)
        # The trailing axis added here is presumably so `numpy_to_pil` treats the mask as single-channel.
        mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
        # Resize the mask to the input image's (width, height) so it lines up with `image`.
        mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)

        # Delegate the actual inpainting to the standard pipeline, sharing this pipeline's components.
        inpainting_pipeline = StableDiffusionInpaintPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=self.scheduler,
            safety_checker=self.safety_checker,
            feature_extractor=self.feature_extractor,
        )
        return inpainting_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_pil,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
|
|