Safe Stable Diffusion
Safe Stable Diffusion was proposed in Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models. It mitigates the well-known issue that models like Stable Diffusion, which are trained on unfiltered, web-crawled datasets, tend to suffer from inappropriate degeneration. For instance, Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces such content.
The abstract of the paper is the following:
Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed, inappropriate image prompts (I2P), containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.
Overview:
Pipeline | Tasks | Colab | Demo
---|---|---|---
pipeline_stable_diffusion_safe.py | Text-to-Image Generation | |
Tips
- Safe Stable Diffusion may also be used with weights of Stable Diffusion.
Run Safe Stable Diffusion
Safe Stable Diffusion can be tested very easily with the StableDiffusionPipelineSafe and the "AIML-TUDA/stable-diffusion-safe" checkpoint, exactly as shown in the Conditional Image Generation Guide.
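As a minimal sketch of such a run, assuming torch and diffusers are installed, a CUDA device is available, and the "AIML-TUDA/stable-diffusion-safe" checkpoint can be downloaded: load_safe_pipeline and generate are hypothetical helper names, and diffusers is imported inside the loader only so the helpers can be defined without it.

```python
from typing import Any, List, Optional, Tuple


def load_safe_pipeline(device: str = "cuda") -> Any:
    """Hypothetical helper: load the Safe Stable Diffusion checkpoint onto a device."""
    # Imported lazily so this file can be imported without torch/diffusers installed.
    import torch
    from diffusers import StableDiffusionPipelineSafe

    # Half precision is only a sensible default on GPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipeline = StableDiffusionPipelineSafe.from_pretrained(
        "AIML-TUDA/stable-diffusion-safe", torch_dtype=dtype
    )
    return pipeline.to(device)


def generate(pipeline: Any, prompt: str, **sld_kwargs: Any) -> Tuple[List[Any], Optional[str]]:
    """Hypothetical helper: run one text-to-image call.

    Returns the generated images and the safety concept reported by the output.
    Any sld_* keyword arguments are forwarded unchanged to the pipeline call.
    """
    out = pipeline(prompt=prompt, **sld_kwargs)
    return out.images, out.applied_safety_concept
```

For example, images, concept = generate(load_safe_pipeline(), "a photograph of an astronaut riding a horse") followed by images[0].save("astronaut.png").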
Interacting with the Safety Concept
To check and edit the currently used safety concept, use the safety_concept property of StableDiffusionPipelineSafe:
>>> from diffusers import StableDiffusionPipelineSafe
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.safety_concept
For each image generation, the active concept is also contained in the applied_safety_concept field of StableDiffusionSafePipelineOutput.
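One way to edit the concept is sketched below. It assumes that safety_concept is a plain comma-separated string that can be reassigned on the pipeline, as the "check and edit" wording above suggests; extend_safety_concept and add_terms are hypothetical helpers, and any extra terms you pass are your own choice, not part of the default concept.

```python
from typing import Any, Iterable


def extend_safety_concept(concept: str, extra_terms: Iterable[str]) -> str:
    """Append comma-separated terms to a concept string, skipping ones already present."""
    existing = [t.strip() for t in concept.split(",") if t.strip()]
    seen = {t.lower() for t in existing}
    for term in extra_terms:
        term = term.strip()
        # Case-insensitive de-duplication keeps the concept prompt short.
        if term and term.lower() not in seen:
            existing.append(term)
            seen.add(term.lower())
    return ", ".join(existing)


def add_terms(pipeline: Any, extra_terms: Iterable[str]) -> str:
    """Hypothetical helper: edit the pipeline's safety concept in place and return it."""
    pipeline.safety_concept = extend_safety_concept(pipeline.safety_concept, extra_terms)
    return pipeline.safety_concept
```

After calling add_terms on a loaded pipeline, the concept actually used for a generation can be confirmed through the applied_safety_concept field of the output.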
Using pre-defined safety configurations
You can use the four configurations defined in the Safe Latent Diffusion paper as follows:
>>> from diffusers import StableDiffusionPipelineSafe
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
The following configurations are available: SafetyConfig.WEAK, SafetyConfig.MEDIUM, SafetyConfig.STRONG, and SafetyConfig.MAX.
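Because each preset is unpacked with **, it behaves like a mapping from sld_* keyword names to values. The sketch below derives a custom preset from an existing one under the assumption that the presets contain only the five SLD arguments of StableDiffusionPipelineSafe.__call__; SLD_KEYS and derive_config are hypothetical names, and the check simply catches misspelled keys before they reach the pipeline.

```python
from typing import Any, Dict, Mapping

# The five SLD keyword arguments accepted by StableDiffusionPipelineSafe.__call__.
SLD_KEYS = frozenset(
    {"sld_guidance_scale", "sld_warmup_steps", "sld_threshold", "sld_momentum_scale", "sld_mom_beta"}
)


def derive_config(preset: Mapping[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Copy a preset and apply overrides, refusing keys that are not SLD arguments."""
    unknown = (set(preset) | set(overrides)) - SLD_KEYS
    if unknown:
        raise KeyError(f"not SLD arguments: {sorted(unknown)}")
    # A new dict is returned, so the original preset is left untouched.
    return {**preset, **overrides}
```

For example, pipeline(prompt=prompt, **derive_config(SafetyConfig.STRONG, sld_warmup_steps=5)) would start from the STRONG preset but change only the warmup.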
How to load and use different schedulers
The Safe Stable Diffusion pipeline uses the PNDMScheduler by default, but diffusers provides many other schedulers that can be used with the Stable Diffusion pipeline, such as DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, and EulerAncestralDiscreteScheduler.
To use a different scheduler, you can either change it via the ConfigMixin.from_config() method or pass the scheduler argument to the from_pretrained method of the pipeline. For example, to use the EulerDiscreteScheduler, you can do the following:
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
... )
StableDiffusionSafePipelineOutput
class diffusers.pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray], nsfw_content_detected: typing.Optional[typing.List[bool]], unsafe_images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray, NoneType], applied_safety_concept: typing.Optional[str] )
Parameters
- images (List[PIL.Image.Image] or np.ndarray) — List of denoised PIL images of length batch_size or a NumPy array of shape (batch_size, height, width, num_channels). The PIL images or NumPy array represent the denoised images of the diffusion pipeline.
- nsfw_content_detected (List[bool]) — List of flags denoting whether the corresponding generated image likely represents “not-safe-for-work” (nsfw) content, or None if safety checking could not be performed.
- unsafe_images (List[PIL.Image.Image] or np.ndarray) — List of denoised PIL images that were flagged by the safety checker and may contain “not-safe-for-work” (nsfw) content, or None if no safety check was performed or no images were flagged.
- applied_safety_concept (str) — The safety concept that was applied for safety guidance, or None if safety guidance was disabled.
Output class for Safe Stable Diffusion pipelines.
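A sketch of consuming this output, pairing each entry of images with its flag in nsfw_content_detected. split_by_nsfw_flag is a hypothetical helper; treating every image as unflagged when the flags are None (safety checking not performed) is a policy choice of this sketch, and you may prefer to reject such outputs instead.

```python
from typing import Any, List, Tuple


def split_by_nsfw_flag(output: Any) -> Tuple[List[Any], List[Any]]:
    """Split output.images into (unflagged, flagged) using output.nsfw_content_detected.

    If the flags are None, no safety check ran, and all images are returned as unflagged.
    """
    flags = output.nsfw_content_detected
    if flags is None:
        return list(output.images), []
    # zip pairs each image (PIL image or array slice) with its boolean flag.
    kept = [img for img, is_flagged in zip(output.images, flags) if not is_flagged]
    flagged = [img for img, is_flagged in zip(output.images, flags) if is_flagged]
    return kept, flagged
```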
StableDiffusionPipelineSafe
class diffusers.StableDiffusionPipelineSafe
( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: SafeStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True )
Parameters
- vae (AutoencoderKL) — Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
- text_encoder (CLIPTextModel) — Frozen text encoder. Stable Diffusion uses the text portion of CLIP, specifically the clip-vit-large-patch14 variant.
- tokenizer (CLIPTokenizer) — Tokenizer of class CLIPTokenizer.
- unet (UNet2DConditionModel) — Conditional U-Net architecture to denoise the encoded image latents.
- scheduler (SchedulerMixin) — A scheduler to be used in combination with unet to denoise the encoded image latents, for example DDIMScheduler, LMSDiscreteScheduler, or PNDMScheduler.
- safety_checker (SafeStableDiffusionSafetyChecker) — Classification module that estimates whether generated images could be considered offensive or harmful. Refer to the model card for details.
- feature_extractor (CLIPImageProcessor) — Model that extracts features from generated images to be used as inputs for the safety_checker.
Pipeline for text-to-image generation using Safe Latent Diffusion.
The implementation is based on the StableDiffusionPipeline.
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
__call__
(
prompt: typing.Union[str, typing.List[str]],
height: typing.Optional[int] = None,
width: typing.Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: typing.Union[str, typing.List[str], NoneType] = None,
num_images_per_prompt: typing.Optional[int] = 1,
eta: float = 0.0,
generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
latents: typing.Optional[torch.FloatTensor] = None,
output_type: typing.Optional[str] = 'pil',
return_dict: bool = True,
callback: typing.Union[typing.Callable[[int, int, torch.FloatTensor], NoneType], NoneType] = None,
callback_steps: int = 1,
sld_guidance_scale: typing.Optional[float] = 1000,
sld_warmup_steps: typing.Optional[int] = 10,
sld_threshold: typing.Optional[float] = 0.01,
sld_momentum_scale: typing.Optional[float] = 0.3,
sld_mom_beta: typing.Optional[float] = 0.4
)
→ StableDiffusionSafePipelineOutput or tuple
Parameters
- prompt (str or List[str]) — The prompt or prompts to guide the image generation.
- height (int, optional, defaults to self.unet.config.sample_size * self.vae_scale_factor) — The height in pixels of the generated image.
- width (int, optional, defaults to self.unet.config.sample_size * self.vae_scale_factor) — The width in pixels of the generated image.
- num_inference_steps (int, optional, defaults to 50) — The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
- guidance_scale (float, optional, defaults to 7.5) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance is enabled by setting guidance_scale > 1. A higher guidance scale encourages the model to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- num_images_per_prompt (int, optional, defaults to 1) — The number of images to generate per prompt.
- eta (float, optional, defaults to 0.0) — Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to schedulers.DDIMScheduler and is ignored for other schedulers.
- generator (torch.Generator, optional) — One or a list of torch generator(s) to make generation deterministic.
- latents (torch.FloatTensor, optional) — Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling with the supplied random generator.
- output_type (str, optional, defaults to "pil") — The output format of the generated image. Choose between PIL (PIL.Image.Image) or np.array.
- return_dict (bool, optional, defaults to True) — Whether or not to return a StableDiffusionSafePipelineOutput instead of a plain tuple.
- callback (Callable, optional) — A function that is called every callback_steps steps during inference. The function is called with the following arguments: callback(step: int, timestep: int, latents: torch.FloatTensor).
- callback_steps (int, optional, defaults to 1) — The frequency at which the callback function is called. If not specified, the callback is called at every step.
- sld_guidance_scale (float, optional, defaults to 1000) — Safe latent guidance as defined in Safe Latent Diffusion. sld_guidance_scale is defined as s_S of Eq. 6. If set to less than 1, safety guidance is disabled.
- sld_warmup_steps (int, optional, defaults to 10) — Number of warmup steps for safety guidance. SLD is only applied for diffusion steps greater than sld_warmup_steps. sld_warmup_steps is defined as delta of Safe Latent Diffusion.
- sld_threshold (float, optional, defaults to 0.01) — Threshold that separates the hyperplane between appropriate and inappropriate images. sld_threshold is defined as lambda of Eq. 5 in Safe Latent Diffusion.
- sld_momentum_scale (float, optional, defaults to 0.3) — Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0, momentum is disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than sld_warmup_steps. sld_momentum_scale is defined as s_m of Eq. 7 in Safe Latent Diffusion.
- sld_mom_beta (float, optional, defaults to 0.4) — Defines how safety guidance momentum builds up. sld_mom_beta indicates how much of the previous momentum is kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller than sld_warmup_steps. sld_mom_beta is defined as beta_m of Eq. 8 in Safe Latent Diffusion.
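To show how the five sld_* parameters interact within a single denoising step, here is a simplified NumPy sketch. It follows the parameter descriptions above and our understanding of the reference implementation (an element-wise safety scale bounded by 1, zeroed where the prompt/concept difference reaches the threshold, momentum that accumulates during warmup, and subtraction of the safety direction only after warmup); it is an illustration, not the pipeline's code, and details may differ. eps_uncond, eps_text, and eps_safety are hypothetical names for the U-Net noise predictions for the unconditional input, the text prompt, and the safety concept.

```python
from typing import Tuple

import numpy as np


def sld_guided_noise(
    eps_uncond: np.ndarray,
    eps_text: np.ndarray,
    eps_safety: np.ndarray,
    momentum: np.ndarray,
    step: int,
    guidance_scale: float = 7.5,
    sld_guidance_scale: float = 1000.0,
    sld_warmup_steps: int = 10,
    sld_threshold: float = 0.01,
    sld_momentum_scale: float = 0.3,
    sld_mom_beta: float = 0.4,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (guided noise prediction, updated momentum) for one 0-indexed step."""
    # Plain classifier-free guidance direction.
    noise_guidance = eps_text - eps_uncond

    if sld_guidance_scale < 1.0:
        # Safety guidance disabled: ordinary classifier-free guidance.
        return eps_uncond + guidance_scale * noise_guidance, momentum

    # Element-wise safety scale, bounded above by 1.
    diff = eps_text - eps_safety
    scale = np.minimum(np.abs(diff) * sld_guidance_scale, 1.0)
    # Dimensions where the difference reaches the threshold are left alone.
    scale = np.where(diff >= sld_threshold, 0.0, scale)

    # Direction towards the safety concept, plus momentum from earlier steps.
    safety_guidance = (eps_safety - eps_uncond) * scale + sld_momentum_scale * momentum
    # Momentum keeps a fraction sld_mom_beta of its previous value (also during warmup).
    new_momentum = sld_mom_beta * momentum + (1.0 - sld_mom_beta) * safety_guidance

    if step >= sld_warmup_steps:
        # After warmup, steer away from the safety concept.
        noise_guidance = noise_guidance - safety_guidance

    return eps_uncond + guidance_scale * noise_guidance, new_momentum
```

In a full sampling loop, the momentum returned at one step is passed into the next, starting from an all-zero array.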
Returns
StableDiffusionSafePipelineOutput or tuple
StableDiffusionSafePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of bools denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the safety_checker.
Function invoked when calling the pipeline for generation.
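The callback argument receives (step, timestep, latents) every callback_steps steps. A minimal sketch of a progress recorder; StepRecorder is a hypothetical name, and it reads only latents.shape so that it does not depend on whether the latents are torch tensors or NumPy arrays.

```python
from typing import Any, List, Tuple


class StepRecorder:
    """Hypothetical callback that records which denoising steps were reported."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, int, Tuple[int, ...]]] = []

    def __call__(self, step: int, timestep: Any, latents: Any) -> None:
        # timestep may arrive as a 0-d tensor; int() converts it either way.
        self.records.append((step, int(timestep), tuple(latents.shape)))
```

For example, with recorder = StepRecorder(), calling pipeline(prompt, callback=recorder, callback_steps=5) leaves one entry in recorder.records for every fifth step.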
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet, text_encoder, vae, and safety checker have their state dicts saved to CPU and are then moved to a torch.device('meta') and loaded to GPU only when their specific submodule has its forward method called.