license: mit
tags:
  - stable-diffusion
  - stable-diffusion-diffusers
inference: false

SDXL - VAE

How to use with 🧨 diffusers

You can integrate this fine-tuned VAE decoder into your existing diffusers workflows by passing a vae argument to the StableDiffusionPipeline:

from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionPipeline

model = "stabilityai/your-stable-diffusion-model"
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
pipe = StableDiffusionPipeline.from_pretrained(model, vae=vae)

How to encode and decode an image

import torch
from PIL import Image
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
import matplotlib.pyplot as plt

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained VAE model
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae.to(device)
vae.eval()

# Load Image processor
image_processor = VaeImageProcessor()

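# Load an image (replace the placeholder with the path to your image)
# and convert it to RGB so the encoder receives three channels
image = Image.open("Paste Image here").convert("RGB")

# Preprocess the image: resize to 256x256 and convert to a normalized tensor
image_tensor = image_processor.preprocess(image, height=256, width=256, resize_mode="fill").to(device)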

# Encode the image
with torch.no_grad():
    latent_representation = vae.encode(image_tensor).latent_dist.sample()

# Decode the latent representation back to image
with torch.no_grad():
    reconstructed_image = vae.decode(latent_representation).sample

# Convert the decoded tensor to a displayable image
reconstructed_image = reconstructed_image.cpu()
reconstructed_image = image_processor.postprocess(reconstructed_image, output_type="pil")
reconstructed_image = reconstructed_image[0]

# Plot the original and reconstructed images side by side
plt.figure(figsize=(10, 5))

# Original image
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Reconstructed image
plt.subplot(1, 2, 2)
plt.imshow(reconstructed_image)
plt.title("Reconstructed Image")
plt.axis("off")

plt.show()
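
To make the size of the latent produced by the encode step concrete, here is a small sketch of the shape arithmetic. It assumes the 8x spatial downsampling implied by the "f8" naming in the evaluation section and 4 latent channels; both are assumptions to check against vae.config (for example vae.config.latent_channels). The helper name latent_shape is hypothetical and not part of diffusers.

```python
def latent_shape(height, width, downsample_factor=8, latent_channels=4):
    """Return the (channels, height, width) of the latent for one input image.

    downsample_factor and latent_channels are assumed defaults, not values
    read from the model; verify them against the loaded VAE's config.
    """
    # This sketch only handles sizes that divide evenly by the factor.
    if height % downsample_factor or width % downsample_factor:
        raise ValueError("height and width must be multiples of the downsample factor")
    return (latent_channels, height // downsample_factor, width // downsample_factor)


# The example above preprocesses the image to 256x256.
print(latent_shape(256, 256))
```

Under these assumptions, latent_representation in the example would have one batch dimension followed by this shape.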

Model

SDXL is a latent diffusion model, where the diffusion operates in a pretrained, learned (and fixed) latent space of an autoencoder. While the bulk of the semantic composition is done by the latent diffusion model, we can improve local, high-frequency details in generated images by improving the quality of the autoencoder. To this end, we train the same autoencoder architecture used for the original Stable Diffusion at a larger batch size (256 vs 9) and additionally track the weights with an exponential moving average (EMA). The resulting autoencoder outperforms the original model in all evaluated reconstruction metrics (see the table below).
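
The EMA weight tracking mentioned above can be illustrated with a minimal sketch. It assumes a plain dict of floats standing in for the model parameters, and the default decay of 0.999 is a hypothetical value, since the decay actually used is not stated here. This is not the training code for this model.

```python
def ema_update(shadow, params, decay=0.999):
    """Blend the current parameters into the running EMA copy, in place.

    shadow: dict of EMA values, updated in place and returned.
    params: dict of current parameter values with the same keys.
    decay:  assumed smoothing factor; higher values change the EMA more slowly.
    """
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value
    return shadow


# Toy usage: a single "weight" that the optimizer has moved to 1.0.
shadow = {"w": 0.0}
for current_value in (1.0, 1.0):
    ema_update(shadow, {"w": current_value}, decay=0.5)
print(shadow["w"])
```

The EMA copy lags behind the raw weights and averages out step-to-step noise.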

Evaluation

SDXL-VAE vs original kl-f8 VAE vs f8-ft-MSE

COCO 2017 (256x256, val, 5000 images)

| Model | rFID | PSNR | SSIM | PSIM | Link | Comments |
|---|---|---|---|---|---|---|
| SDXL-VAE | 4.42 | 24.7 +/- 3.9 | 0.73 +/- 0.13 | 0.88 +/- 0.27 | https://huggingface.co/stabilityai/sdxl-vae/blob/main/sdxl_vae.safetensors | as used in SDXL |
| original | 4.99 | 23.4 +/- 3.8 | 0.69 +/- 0.14 | 1.01 +/- 0.28 | https://ommer-lab.com/files/latent-diffusion/kl-f8.zip | as used in SD |
| ft-MSE | 4.70 | 24.5 +/- 3.7 | 0.71 +/- 0.13 | 0.92 +/- 0.27 | https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt | resumed with EMA from ft-EMA, emphasis on MSE (rec. loss = MSE + 0.1 * LPIPS), smoother outputs |
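
For readers unfamiliar with the PSNR column, the sketch below shows the standard definition, PSNR = 10 * log10(MAX^2 / MSE), on flat sequences of 8-bit pixel values. It is an illustration only and not the evaluation code used to produce the numbers above; the function name psnr and the max_value default of 255 are assumptions.

```python
import math


def psnr(original, reconstructed, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(original) != len(reconstructed) or not original:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixel values.
    mse = sum((a - b) ** 2 for a, b in zip(original, reconstructed)) / len(original)
    if mse == 0:
        return math.inf  # identical inputs: no reconstruction error
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy usage: every pixel is off by 10 intensity levels.
print(round(psnr([0, 0, 0, 0], [10, 10, 10, 10]), 2))
```

Higher PSNR means the reconstruction is closer to the original image.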