import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.autograd import grad
import argparse
from tqdm import tqdm
from syncdiffusion.utils import *
import lpips
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
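
# NOTE (descriptive comment, not an upstream API contract): `get_views` and
# `exponential_decay_list` come from `syncdiffusion.utils` via the wildcard import
# above. Based on how they are used below, they are assumed to behave roughly as:
#   - get_views(height, width, stride): returns a list of latent-space window
#     coordinates (h_start, h_end, w_start, w_end), each covering a 64x64 latent
#     window, shifted by `stride` so that adjacent windows overlap.
#   - exponential_decay_list(init_weight, decay_rate, num_steps): returns a
#     per-timestep weight schedule, approximately init_weight * decay_rate ** i.
# See syncdiffusion/utils.py for the actual implementations.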
class SyncDiffusion(nn.Module):
    def __init__(self, device='cuda', sd_version='2.0', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Load pretrained models from HuggingFace
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        # Freeze models
        for p in self.unet.parameters():
            p.requires_grad_(False)
        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.text_encoder.parameters():
            p.requires_grad_(False)
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()
        print(f'[INFO] loaded stable diffusion!')

        # Set DDIM scheduler
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        # Load perceptual loss (LPIPS)
        self.percept_loss = lpips.LPIPS(net='vgg').to(self.device)
        print(f'[INFO] loaded perceptual loss!')
    def get_text_embeds(self, prompt, negative_prompt):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Repeat for unconditional embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Concatenate for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings
    def decode_latents(self, latents):
        # Undo the Stable Diffusion VAE scaling factor (0.18215) before decoding
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        # Map decoded images from [-1, 1] to [0, 1]
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs
    def sample_syncdiffusion(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=2048,
        latent_size=64,         # fix latent size to 64 for Stable Diffusion
        num_inference_steps=50,
        guidance_scale=7.5,
        sync_weight=20,         # gradient descent weight 'w' in the paper
        sync_freq=1,            # sync_freq=n: perform gradient descent every n steps
        sync_thres=50,          # sync_thres=n: compute SyncDiffusion only for the first n steps
        sync_decay_rate=0.95,   # decay rate for sync_weight, set as 0.95 in the paper
        stride=16,              # stride between adjacent latent views, set as 16 in the paper
    ):
        assert height >= 512 and width >= 512, 'height and width must be at least 512'
        assert height % (stride * 8) == 0 and width % (stride * 8) == 0, 'height and width must be divisible by the stride multiplied by 8'
        assert stride % 8 == 0 and stride < 64, 'stride must be divisible by 8 and smaller than the latent size of Stable Diffusion'

        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # obtain text embeddings
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # (2, 77, embedding_dim)

        # define a list of windows to process in parallel
        views = get_views(height, width, stride=stride)
        print(f"[INFO] number of views to process: {len(views)}")

        # initialize the full panorama latent and the per-pixel accumulators
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
        count = torch.zeros_like(latent, requires_grad=False)
        value = torch.zeros_like(latent, requires_grad=False)

        # set DDIM scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # set the anchor view as the middle view
        anchor_view_idx = len(views) // 2

        # set SyncDiffusion weight scheduler
        sync_scheduler = exponential_decay_list(
            init_weight=sync_weight,
            decay_rate=sync_decay_rate,
            num_steps=num_inference_steps
        )
        print(f'[INFO] using exponential decay scheduler with decay rate {sync_decay_rate}')
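
        # Denoising loop: at every timestep, (1) compute the anchor view's 'foreseen'
        # clean image x0, (2) for every other view, take a gradient step on its latent
        # that pulls its own foreseen x0 toward the anchor's in LPIPS space
        # (SyncDiffusion), (3) run one DDIM denoising step per view, and (4) average
        # the overlapping view latents as in MultiDiffusion.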
        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                count.zero_()
                value.zero_()

                '''
                (1) First, obtain the reference anchor view (for computing the perceptual loss)
                '''
                with torch.no_grad():
                    if (i + 1) % sync_freq == 0 and i < sync_thres:
                        # decode the anchor view
                        h_start, h_end, w_start, w_end = views[anchor_view_idx]
                        latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()
                        latent_model_input = torch.cat([latent_view] * 2)  # 2 x 4 x 64 x 64

                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        # perform guidance
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        # predict the 'foreseen denoised' latent (x0) of the anchor view
                        latent_pred_x0 = self.scheduler.step(noise_pred_new, t, latent_view)["pred_original_sample"]
                        decoded_image_anchor = self.decode_latents(latent_pred_x0)  # 1 x 3 x 512 x 512

                '''
                (2) Then perform SyncDiffusion and run a single denoising step
                '''
                for view_idx, (h_start, h_end, w_start, w_end) in enumerate(views):
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()

                    ############################## BEGIN: PERFORM GRADIENT DESCENT (SyncDiffusion) ##############################
                    latent_view_copy = latent_view.clone().detach()

                    if (i + 1) % sync_freq == 0 and i < sync_thres:
                        # enable gradients on the view latent
                        latent_view = latent_view.requires_grad_()

                        # expand the latents for classifier-free guidance
                        latent_model_input = torch.cat([latent_view] * 2)

                        # predict the noise residual
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        # perform guidance
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        # compute the denoising step with the reference model
                        out = self.scheduler.step(noise_pred_new, t, latent_view)

                        # predict the 'foreseen denoised' latent (x0)
                        latent_view_x0 = out['pred_original_sample']

                        # decode the denoised latent
                        decoded_x0 = self.decode_latents(latent_view_x0)  # 1 x 3 x 512 x 512

                        # compute the perceptual loss (LPIPS) against the anchor view
                        percept_loss = self.percept_loss(
                            decoded_x0 * 2.0 - 1.0,
                            decoded_image_anchor * 2.0 - 1.0
                        )

                        # compute the gradient of the perceptual loss w.r.t. the view latent
                        norm_grad = grad(outputs=percept_loss, inputs=latent_view)[0]

                        # SyncDiffusion: update the original latent (the anchor view is left unchanged)
                        if view_idx != anchor_view_idx:
                            latent_view_copy = latent_view_copy - sync_scheduler[i] * norm_grad  # 1 x 4 x 64 x 64
                    ############################## END: PERFORM GRADIENT DESCENT (SyncDiffusion) ##############################
                    # after gradient descent, perform a single denoising step
                    with torch.no_grad():
                        latent_model_input = torch.cat([latent_view_copy] * 2)
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        out = self.scheduler.step(noise_pred_new, t, latent_view_copy)
                        latent_view_denoised = out['prev_sample']

                    # accumulate the denoised latent view into the panorama
                    value[:, :, h_start:h_end, w_start:w_end] += latent_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step (average the overlapping latents)
                latent = torch.where(count > 0, value / count, value)

        # decode latents to the panorama image
        with torch.no_grad():
            imgs = self.decode_latents(latent)  # (1, 3, height, width)
            img = T.ToPILImage()(imgs[0].cpu())

        print(f"[INFO] Done!")
        return img
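
# Example usage (illustrative sketch, not part of the original file): generates a
# 512 x 2048 panorama with the default SyncDiffusion settings. The argument names
# below are placeholders, not the CLI of the accompanying sample script.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default='a photo of a mountain landscape')
    parser.add_argument('--negative_prompt', type=str, default='')
    parser.add_argument('--height', type=int, default=512)
    parser.add_argument('--width', type=int, default=2048)
    parser.add_argument('--sync_weight', type=float, default=20.0)
    parser.add_argument('--sd_version', type=str, default='2.0')
    args = parser.parse_args()

    syncdiffusion = SyncDiffusion(device='cuda', sd_version=args.sd_version)
    panorama = syncdiffusion.sample_syncdiffusion(
        prompts=args.prompt,
        negative_prompts=args.negative_prompt,
        height=args.height,
        width=args.width,
        sync_weight=args.sync_weight,
    )
    panorama.save('syncdiffusion_panorama.png')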