phillipinseoul committed on
Commit 47a3cb0
Parent: 31d5771

add app.py

app.py ADDED
@@ -0,0 +1,87 @@
+ """
+ app.py
+ An interactive demo for text-guided panorama generation.
+ """
+ import torch
+ import gradio as gr
+
+ from syncdiffusion.syncdiffusion_model import SyncDiffusion
+ from syncdiffusion.utils import seed_everything
+
+ # NOTE: the first five parameters are ordered to match the Gradio `inputs` list below,
+ # since Gradio passes component values to the callback positionally.
+ def run_inference(
+     prompt: str,
+     width: int = 2048,
+     sync_weight: float = 20.0,
+     sync_thres: int = 5,
+     seed: int = 0,
+     height: int = 512,
+     sync_decay_rate: float = 0.95,
+     sync_freq: int = 1,
+ ):
+     # set device
+     device = torch.device("cuda")
+
+     # set random seed
+     seed_everything(seed)
+
+     # load SyncDiffusion model
+     syncdiffusion = SyncDiffusion(device, sd_version="2.0")
+
+     img = syncdiffusion.sample_syncdiffusion(
+         prompts=prompt,
+         negative_prompts="",
+         height=height,
+         width=width,
+         num_inference_steps=50,
+         guidance_scale=7.5,
+         sync_weight=sync_weight,
+         sync_decay_rate=sync_decay_rate,
+         sync_freq=sync_freq,
+         sync_thres=sync_thres,
+         stride=16
+     )
+     return [img]
+
+ if __name__ == "__main__":
+     title = "SyncDiffusion: Text-Guided Panorama Generation"
+
+     description_text = '''
+     This demo features text-guided panorama generation from our work <a href="https://arxiv.org/abs/2306.05178">SyncDiffusion: Coherent Montage via Synchronized Joint Diffusions, NeurIPS 2023</a>.
+     Please refer to our <a href="https://syncdiffusion.github.io/">project page</a> for details.
+     '''
+
+     # create UI
+     with gr.Blocks(title=title) as demo:
+
+         # description of demo
+         gr.Markdown(description_text)
+
+         # inputs
+         with gr.Row():
+             with gr.Column():
+                 run_button = gr.Button(label="Generate")
+
+                 prompt = gr.Textbox(label="Text Prompt", value='a cinematic view of a castle in the sunset')
+                 width = gr.Slider(label="Width", minimum=512, maximum=4096, value=2048, step=128)
+                 sync_weight = gr.Slider(label="Sync Weight", minimum=0.0, maximum=30.0, value=20.0, step=5.0)
+                 sync_thres = gr.Slider(label="Sync Threshold (If N, apply SyncDiffusion for the first N steps)", minimum=0, maximum=50, value=5, step=1)
+                 seed = gr.Number(label="Seed", value=0)
+
+             with gr.Column():
+                 result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+
+         # display examples
+         examples = gr.Examples(
+             examples=[
+                 'a cinematic view of a castle in the sunset',
+                 'natural landscape in anime style illustration',
+                 'a photo of a lake under the northern lights',
+             ],
+             inputs=[prompt],
+         )
+
+         ips = [prompt, width, sync_weight, sync_thres, seed]
+         # pass the callback itself (not its return value) to `click`
+         run_button.click(fn=run_inference, inputs=ips, outputs=[result_gallery])
+
+     demo.queue(max_size=30)
+     demo.launch()
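For quick testing outside the Gradio UI, the same pipeline can be driven directly from Python. The following is a minimal sketch based on the defaults used above; it assumes a CUDA device and an output filename of my choosing, and is not part of this commit:

import torch
from syncdiffusion.syncdiffusion_model import SyncDiffusion
from syncdiffusion.utils import seed_everything

seed_everything(0)
model = SyncDiffusion(torch.device("cuda"), sd_version="2.0")

# generate a 512x2048 panorama with the demo defaults
panorama = model.sample_syncdiffusion(
    prompts="a cinematic view of a castle in the sunset",
    negative_prompts="",
    height=512,
    width=2048,
    num_inference_steps=50,
    guidance_scale=7.5,
    sync_weight=20.0,
    sync_decay_rate=0.95,
    sync_freq=1,
    sync_thres=5,
    stride=16,
)
panorama.save("panorama.png")  # sample_syncdiffusion returns a PIL image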
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch==1.12.1
+ torchvision==0.13.1
+ transformers==4.28
+ diffusers==0.15.1
+ opencv-python==4.5.1.48
+ lpips
+ accelerate
+ tqdm
+ gradio
syncdiffusion/__pycache__/model.cpython-38.pyc ADDED
Binary file (5.81 kB).
syncdiffusion/__pycache__/syncdiffusion.cpython-38.pyc ADDED
Binary file (5.75 kB).
syncdiffusion/__pycache__/syncdiffusion_model.cpython-38.pyc ADDED
Binary file (5.53 kB).
syncdiffusion/__pycache__/syncdiffusion_model.cpython-39.pyc ADDED
Binary file (5.63 kB).
syncdiffusion/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.3 kB).
syncdiffusion/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.29 kB).
syncdiffusion/syncdiffusion_model.py ADDED
@@ -0,0 +1,232 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+ from torch.autograd import grad
+ import argparse
+ from tqdm import tqdm
+
+ from syncdiffusion.utils import *
+ import lpips
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
+
+ class SyncDiffusion(nn.Module):
+     def __init__(self, device='cuda', sd_version='2.0', hf_key=None):
+         super().__init__()
+
+         self.device = device
+         self.sd_version = sd_version
+
+         print(f'[INFO] loading stable diffusion...')
+         if hf_key is not None:
+             print(f'[INFO] using hugging face custom model key: {hf_key}')
+             model_key = hf_key
+         elif self.sd_version == '2.1':
+             model_key = "stabilityai/stable-diffusion-2-1-base"
+         elif self.sd_version == '2.0':
+             model_key = "stabilityai/stable-diffusion-2-base"
+         elif self.sd_version == '1.5':
+             model_key = "runwayml/stable-diffusion-v1-5"
+         else:
+             raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
+
+         # Load pretrained models from HuggingFace
+         self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
+         self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
+         self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
+         self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
+
+         # Freeze models
+         for p in self.unet.parameters():
+             p.requires_grad_(False)
+         for p in self.vae.parameters():
+             p.requires_grad_(False)
+         for p in self.text_encoder.parameters():
+             p.requires_grad_(False)
+
+         self.unet.eval()
+         self.vae.eval()
+         self.text_encoder.eval()
+         print(f'[INFO] loaded stable diffusion!')
+
+         # Set DDIM scheduler
+         self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+
+         # load perceptual loss (LPIPS)
+         self.percept_loss = lpips.LPIPS(net='vgg').to(self.device)
+         print(f'[INFO] loaded perceptual loss!')
+
+     def get_text_embeds(self, prompt, negative_prompt):
+         # Tokenize text and get embeddings
+         text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
+                                     truncation=True, return_tensors='pt')
+         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+         # Repeat for unconditional embeddings
+         uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
+                                       return_tensors='pt')
+         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+         # Concatenate for final embeddings
+         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+         return text_embeddings
+
+     def decode_latents(self, latents):
+         latents = 1 / 0.18215 * latents
+         imgs = self.vae.decode(latents).sample
+         imgs = (imgs / 2 + 0.5).clamp(0, 1)
+         return imgs
+
+     def sample_syncdiffusion(
+         self,
+         prompts,
+         negative_prompts="",
+         height=512,
+         width=2048,
+         latent_size=64,         # fix latent size to 64 for Stable Diffusion
+         num_inference_steps=50,
+         guidance_scale=7.5,
+         sync_weight=20,         # gradient descent weight 'w' in the paper
+         sync_freq=1,            # sync_freq=n: perform gradient descent every n steps
+         sync_thres=50,          # sync_thres=n: compute SyncDiffusion only for the first n steps
+         sync_decay_rate=0.95,   # decay rate for sync_weight, set as 0.95 in the paper
+         stride=16,              # stride for latents, set as 16 in the paper
+     ):
+         assert height >= 512 and width >= 512, 'height and width must be at least 512'
+         assert height % (stride * 8) == 0 and width % (stride * 8) == 0, 'height and width must be divisible by the stride multiplied by 8'
+         assert stride % 8 == 0 and stride < 64, 'stride must be divisible by 8 and smaller than the latent size of Stable Diffusion'
+
+         if isinstance(prompts, str):
+             prompts = [prompts]
+
+         if isinstance(negative_prompts, str):
+             negative_prompts = [negative_prompts]
+
+         # obtain text embeddings
+         text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]
+
+         # define a list of windows to process in parallel
+         views = get_views(height, width, stride=stride)
+         print(f"[INFO] number of views to process: {len(views)}")
+
+         # Initialize latent
+         latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8))
+
+         count = torch.zeros_like(latent, requires_grad=False, device=self.device)
+         value = torch.zeros_like(latent, requires_grad=False, device=self.device)
+         latent = latent.to(self.device)
+
+         # set DDIM scheduler
+         self.scheduler.set_timesteps(num_inference_steps)
+
+         # set the anchor view as the middle view
+         anchor_view_idx = len(views) // 2
+
+         # set SyncDiffusion scheduler
+         sync_scheduler = exponential_decay_list(
+             init_weight=sync_weight,
+             decay_rate=sync_decay_rate,
+             num_steps=num_inference_steps
+         )
+         print(f'[INFO] using exponential decay scheduler with decay rate {sync_decay_rate}')
+
+         with torch.autocast('cuda'):
+             for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+                 count.zero_()
+                 value.zero_()
+
+                 '''
+                 (1) First, obtain the reference anchor view (for computing the perceptual loss)
+                 '''
+                 with torch.no_grad():
+                     if (i + 1) % sync_freq == 0 and i < sync_thres:
+                         # decode the anchor view
+                         h_start, h_end, w_start, w_end = views[anchor_view_idx]
+                         latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()
+
+                         latent_model_input = torch.cat([latent_view] * 2)  # 2 x 4 x 64 x 64
+                         noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
+
+                         # perform guidance
+                         noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                         noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                         # predict the 'foreseen denoised' latent (x0) of the anchor view
+                         latent_pred_x0 = self.scheduler.step(noise_pred_new, t, latent_view)["pred_original_sample"]
+                         decoded_image_anchor = self.decode_latents(latent_pred_x0)  # 1 x 3 x 512 x 512
+
+                 '''
+                 (2) Then perform SyncDiffusion and run a single denoising step
+                 '''
+                 for view_idx, (h_start, h_end, w_start, w_end) in enumerate(views):
+                     latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()
+
+                     ############################## BEGIN: PERFORM GRADIENT DESCENT (SyncDiffusion) ##############################
+                     latent_view_copy = latent_view.clone().detach()
+
+                     if (i + 1) % sync_freq == 0 and i < sync_thres:
+
+                         # enable gradients w.r.t. latent_view
+                         latent_view = latent_view.requires_grad_()
+
+                         # expand the latents for classifier-free guidance
+                         latent_model_input = torch.cat([latent_view] * 2)
+
+                         # predict the noise residual
+                         noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
+
+                         # perform guidance
+                         noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                         noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                         # compute the denoising step with the reference model
+                         out = self.scheduler.step(noise_pred_new, t, latent_view)
+
+                         # predict the 'foreseen denoised' latent (x0)
+                         latent_view_x0 = out['pred_original_sample']
+
+                         # decode the denoised latent
+                         decoded_x0 = self.decode_latents(latent_view_x0)  # 1 x 3 x 512 x 512
+
+                         # compute the perceptual loss (LPIPS)
+                         percept_loss = self.percept_loss(
+                             decoded_x0 * 2.0 - 1.0,
+                             decoded_image_anchor * 2.0 - 1.0
+                         )
+
+                         # compute the gradient of the perceptual loss w.r.t. the latent
+                         norm_grad = grad(outputs=percept_loss, inputs=latent_view)[0]
+
+                         # SyncDiffusion: update the original latent
+                         if view_idx != anchor_view_idx:
+                             latent_view_copy = latent_view_copy - sync_scheduler[i] * norm_grad  # 1 x 4 x 64 x 64
+                     ############################## END: PERFORM GRADIENT DESCENT (SyncDiffusion) ##############################
+
+                     # after gradient descent, perform a single denoising step
+                     with torch.no_grad():
+                         latent_model_input = torch.cat([latent_view_copy] * 2)
+                         noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
+
+                         noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+                         noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+                         out = self.scheduler.step(noise_pred_new, t, latent_view_copy)
+                         latent_view_denoised = out['prev_sample']
+
+                     # merge the latent views
+                     value[:, :, h_start:h_end, w_start:w_end] += latent_view_denoised
+                     count[:, :, h_start:h_end, w_start:w_end] += 1
+
+                 # take the MultiDiffusion step (average the latents)
+                 latent = torch.where(count > 0, value / count, value)
+
+         # decode latents to panorama image
+         with torch.no_grad():
+             imgs = self.decode_latents(latent)  # 1 x 3 x height x width
+             img = T.ToPILImage()(imgs[0].cpu())
+
+         print(f"[INFO] Done!")
+
+         return img
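To make the per-window update explicit: inside the loop above, every non-anchor window takes one gradient-descent step on the LPIPS distance between its 'foreseen denoised' image and the anchor's, before the usual DDIM step. A stripped-down sketch of just that update, with `decode_x0`, `lpips_fn`, and `weight` standing in for the class attributes and decayed sync weight used above (names of my choosing, not part of the commit):

from torch.autograd import grad

def sync_update(latent_view, decoded_anchor, decode_x0, lpips_fn, weight):
    # enable gradients w.r.t. this window's latent
    latent_view = latent_view.detach().requires_grad_()
    # 'foreseen denoised' image for this window (UNet prediction + scheduler x0 + VAE decode)
    decoded_x0 = decode_x0(latent_view)
    # LPIPS expects inputs scaled to [-1, 1]
    loss = lpips_fn(decoded_x0 * 2.0 - 1.0, decoded_anchor * 2.0 - 1.0)
    # one gradient-descent step toward perceptual agreement with the anchor
    g = grad(outputs=loss, inputs=latent_view)[0]
    return (latent_view - weight * g).detach()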
syncdiffusion/utils.py ADDED
@@ -0,0 +1,28 @@
+ import numpy as np
+ import torch
+
+ def seed_everything(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def get_views(panorama_height, panorama_width, window_size=64, stride=8):
+     panorama_height /= 8
+     panorama_width /= 8
+     num_blocks_height = (panorama_height - window_size) // stride + 1
+     num_blocks_width = (panorama_width - window_size) // stride + 1
+     total_num_blocks = int(num_blocks_height * num_blocks_width)
+     views = []
+     for i in range(total_num_blocks):
+         h_start = int((i // num_blocks_width) * stride)
+         h_end = h_start + window_size
+         w_start = int((i % num_blocks_width) * stride)
+         w_end = w_start + window_size
+         views.append((h_start, h_end, w_start, w_end))
+     return views
+
+ def exponential_decay_list(init_weight, decay_rate, num_steps):
+     weights = [init_weight * (decay_rate ** i) for i in range(num_steps)]
+     return torch.tensor(weights)
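As a quick sanity check (not part of the commit), these helpers give the following for the demo defaults of a 512x2048 panorama with stride=16 and sync_weight=20:

from syncdiffusion.utils import get_views, exponential_decay_list

# the 64x256 latent is tiled into overlapping 64x64 windows, shifted by 16
views = get_views(512, 2048, window_size=64, stride=16)
print(len(views))            # 13 windows; the anchor view is views[13 // 2] = views[6]
print(views[0], views[1])    # (0, 64, 0, 64) (0, 64, 16, 80)

# per-step gradient-descent weights, decayed by 0.95 each denoising step
weights = exponential_decay_list(init_weight=20.0, decay_rate=0.95, num_steps=50)
print(weights[:3])           # approximately [20.0, 19.0, 18.05]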