phillipinseoul committed
Commit 47a3cb0
Parent(s): 31d5771

add app.py

Files changed:
- app.py +87 -0
- requirements.txt +9 -0
- syncdiffusion/__pycache__/model.cpython-38.pyc +0 -0
- syncdiffusion/__pycache__/syncdiffusion.cpython-38.pyc +0 -0
- syncdiffusion/__pycache__/syncdiffusion_model.cpython-38.pyc +0 -0
- syncdiffusion/__pycache__/syncdiffusion_model.cpython-39.pyc +0 -0
- syncdiffusion/__pycache__/utils.cpython-38.pyc +0 -0
- syncdiffusion/__pycache__/utils.cpython-39.pyc +0 -0
- syncdiffusion/syncdiffusion_model.py +232 -0
- syncdiffusion/utils.py +28 -0
app.py
ADDED
@@ -0,0 +1,87 @@
"""
app.py

An interactive demo for text-guided panorama generation.
"""
import torch
import gradio as gr

from syncdiffusion.syncdiffusion_model import SyncDiffusion
from syncdiffusion.utils import seed_everything

def run_inference(
    # UI-controlled arguments come first, matching the order of `ips` in the click handler below
    prompt: str,
    width: int = 2048,
    sync_weight: float = 20.0,
    sync_thres: int = 5,
    seed: int = 0,
    height: int = 512,
    sync_decay_rate: float = 0.95,
    sync_freq: int = 1,
):
    # set device
    device = torch.device("cuda")

    # set random seed
    seed_everything(seed)

    # load SyncDiffusion model
    syncdiffusion = SyncDiffusion(device, sd_version="2.0")

    img = syncdiffusion.sample_syncdiffusion(
        prompts=prompt,
        negative_prompts="",
        height=height,
        width=width,
        num_inference_steps=50,
        guidance_scale=7.5,
        sync_weight=sync_weight,
        sync_decay_rate=sync_decay_rate,
        sync_freq=sync_freq,
        sync_thres=sync_thres,
        stride=16
    )
    return [img]

if __name__ == "__main__":
    title = "SyncDiffusion: Text-Guided Panorama Generation"

    description_text = '''
    This demo features text-guided panorama generation from our work <a href="https://arxiv.org/abs/2306.05178">SyncDiffusion: Coherent Montage via Synchronized Joint Diffusions, NeurIPS 2023</a>.
    Please refer to our <a href="https://syncdiffusion.github.io/">project page</a> for details.
    '''

    # create UI
    with gr.Blocks(title=title) as demo:
        # description of demo
        gr.Markdown(description_text)

        # inputs
        with gr.Row():
            with gr.Column():
                run_button = gr.Button(value="Generate")

                prompt = gr.Textbox(label="Text Prompt", value='a cinematic view of a castle in the sunset')
                width = gr.Slider(label="Width", minimum=512, maximum=4096, value=2048, step=128)
                sync_weight = gr.Slider(label="Sync Weight", minimum=0.0, maximum=30.0, value=20.0, step=5.0)
                sync_thres = gr.Slider(label="Sync Threshold (If N, apply SyncDiffusion for the first N steps)", minimum=0, maximum=50, value=5, step=1)
                seed = gr.Number(label="Seed", value=0)

            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

        # display examples
        examples = gr.Examples(
            examples=[
                'a cinematic view of a castle in the sunset',
                'natural landscape in anime style illustration',
                'a photo of a lake under the northern lights',
            ],
            inputs=[prompt],
        )

        # run inference on click; pass the function itself, not its return value
        ips = [prompt, width, sync_weight, sync_thres, seed]
        run_button.click(fn=run_inference, inputs=ips, outputs=[result_gallery])

    demo.queue(max_size=30)
    demo.launch()
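For a quick sanity check of the pipeline without launching the Gradio UI, a minimal headless sketch could look like the one below. It only uses the classes and arguments committed above; the reduced width of 1024 and the output filename are illustrative choices, not part of the demo.

# headless usage sketch (illustrative; not part of the committed demo)
import torch
from syncdiffusion.syncdiffusion_model import SyncDiffusion
from syncdiffusion.utils import seed_everything

seed_everything(0)
model = SyncDiffusion(torch.device("cuda"), sd_version="2.0")

# a narrower panorama keeps the number of views (and runtime) small
panorama = model.sample_syncdiffusion(
    prompts="a cinematic view of a castle in the sunset",
    negative_prompts="",
    height=512,
    width=1024,
    num_inference_steps=50,
    guidance_scale=7.5,
    sync_weight=20.0,
    sync_decay_rate=0.95,
    sync_freq=1,
    sync_thres=5,
    stride=16,
)
panorama.save("panorama.png")  # sample_syncdiffusion returns a PIL image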
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch==1.12.1
torchvision==0.13.1
transformers==4.28
diffusers==0.15.1
opencv-python==4.5.1.48
lpips
accelerate
tqdm
gradio
syncdiffusion/__pycache__/model.cpython-38.pyc
ADDED
Binary file (5.81 kB)

syncdiffusion/__pycache__/syncdiffusion.cpython-38.pyc
ADDED
Binary file (5.75 kB)

syncdiffusion/__pycache__/syncdiffusion_model.cpython-38.pyc
ADDED
Binary file (5.53 kB)

syncdiffusion/__pycache__/syncdiffusion_model.cpython-39.pyc
ADDED
Binary file (5.63 kB)

syncdiffusion/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.3 kB)

syncdiffusion/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.29 kB)
syncdiffusion/syncdiffusion_model.py
ADDED
@@ -0,0 +1,232 @@
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.autograd import grad
import argparse
from tqdm import tqdm

from syncdiffusion.utils import *
import lpips
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

class SyncDiffusion(nn.Module):
    def __init__(self, device='cuda', sd_version='2.0', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Load pretrained models from HuggingFace
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        # Freeze models
        for p in self.unet.parameters():
            p.requires_grad_(False)
        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.text_encoder.parameters():
            p.requires_grad_(False)

        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()
        print(f'[INFO] loaded stable diffusion!')

        # Set DDIM scheduler
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        # load perceptual loss (LPIPS)
        self.percept_loss = lpips.LPIPS(net='vgg').to(self.device)
        print(f'[INFO] loaded perceptual loss!')

    def get_text_embeds(self, prompt, negative_prompt):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Repeat for unconditional embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Concatenate for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def sample_syncdiffusion(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=2048,
        latent_size=64,          # fix latent size to 64 for Stable Diffusion
        num_inference_steps=50,
        guidance_scale=7.5,
        sync_weight=20,          # gradient descent weight 'w' in the paper
        sync_freq=1,             # sync_freq=n: perform gradient descent every n steps
        sync_thres=50,           # sync_thres=n: compute SyncDiffusion only for the first n steps
        sync_decay_rate=0.95,    # decay rate for sync_weight, set as 0.95 in the paper
        stride=16,               # stride for latents, set as 16 in the paper
    ):
        assert height >= 512 and width >= 512, 'height and width must be at least 512'
        assert height % (stride * 8) == 0 and width % (stride * 8) == 0, 'height and width must be divisible by the stride multiplied by 8'
        assert stride % 8 == 0 and stride < 64, 'stride must be divisible by 8 and smaller than the latent size of Stable Diffusion'

        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # obtain text embeddings
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]

        # define a list of windows to process in parallel
        views = get_views(height, width, stride=stride)
        print(f"[INFO] number of views to process: {len(views)}")

        # Initialize latent
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8))

        count = torch.zeros_like(latent, requires_grad=False, device=self.device)
        value = torch.zeros_like(latent, requires_grad=False, device=self.device)
        latent = latent.to(self.device)

        # set DDIM scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # set the anchor view as the middle view
        anchor_view_idx = len(views) // 2

        # set SyncDiffusion scheduler
        sync_scheduler = exponential_decay_list(
            init_weight=sync_weight,
            decay_rate=sync_decay_rate,
            num_steps=num_inference_steps
        )
        print(f'[INFO] using exponential decay scheduler with decay rate {sync_decay_rate}')

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                count.zero_()
                value.zero_()

                '''
                (1) First, obtain the reference anchor view (for computing the perceptual loss)
                '''
                with torch.no_grad():
                    if (i + 1) % sync_freq == 0 and i < sync_thres:
                        # decode the anchor view
                        h_start, h_end, w_start, w_end = views[anchor_view_idx]
                        latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()

                        latent_model_input = torch.cat([latent_view] * 2)  # 2 x 4 x 64 x 64
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        # perform guidance
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        # predict the 'foreseen denoised' latent (x0) of the anchor view
                        latent_pred_x0 = self.scheduler.step(noise_pred_new, t, latent_view)["pred_original_sample"]
                        decoded_image_anchor = self.decode_latents(latent_pred_x0)  # 1 x 3 x 512 x 512

                '''
                (2) Then perform SyncDiffusion and run a single denoising step
                '''
                for view_idx, (h_start, h_end, w_start, w_end) in enumerate(views):
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()

                    ############################## BEGIN: PERFORM GRADIENT DESCENT (SyncDiffusion) ##############################
                    latent_view_copy = latent_view.clone().detach()

                    #### TODO: TEST ####
                    # if i % sync_freq == 0 and i < sync_thres:
                    if (i + 1) % sync_freq == 0 and i < sync_thres:

                        # gradient on latent_view
                        latent_view = latent_view.requires_grad_()

                        # expand the latents for classifier-free guidance
                        latent_model_input = torch.cat([latent_view] * 2)

                        # predict the noise residual
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        # perform guidance
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        # compute the denoising step with the reference model
                        out = self.scheduler.step(noise_pred_new, t, latent_view)

                        # predict the 'foreseen denoised' latent (x0)
                        latent_view_x0 = out['pred_original_sample']

                        # decode the denoised latent
                        decoded_x0 = self.decode_latents(latent_view_x0)  # 1 x 3 x 512 x 512

                        # compute the perceptual loss (LPIPS)
                        percept_loss = self.percept_loss(
                            decoded_x0 * 2.0 - 1.0,
                            decoded_image_anchor * 2.0 - 1.0
                        )

                        # compute the gradient of the perceptual loss w.r.t. the latent
                        norm_grad = grad(outputs=percept_loss, inputs=latent_view)[0]

                        # SyncDiffusion: update the original latent
                        if view_idx != anchor_view_idx:
                            latent_view_copy = latent_view_copy - sync_scheduler[i] * norm_grad  # 1 x 4 x 64 x 64
                    ############################## END: PERFORM GRADIENT DESCENT (SyncDiffusion) ##############################

                    # after gradient descent, perform a single denoising step
                    with torch.no_grad():
                        latent_model_input = torch.cat([latent_view_copy] * 2)
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        out = self.scheduler.step(noise_pred_new, t, latent_view_copy)
                        latent_view_denoised = out['prev_sample']

                    # merge the latent views
                    value[:, :, h_start:h_end, w_start:w_end] += latent_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step (average the latents)
                latent = torch.where(count > 0, value / count, value)

        # decode latents to panorama image
        with torch.no_grad():
            imgs = self.decode_latents(latent)  # [1, 3, 512, 512]
            img = T.ToPILImage()(imgs[0].cpu())

        print(f"[INFO] Done!")

        return img
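The per-step merge at the end of the view loop is a MultiDiffusion-style average: every overlapping window adds its denoised latent into value, the per-pixel count records how many windows touched each location, and torch.where(count > 0, value / count, value) divides out the overlap. The following toy sketch (not from the repository) reproduces just that averaging step on a tiny tensor.

# toy illustration of the window-averaging step (illustrative only)
import torch

latent = torch.zeros(1, 1, 1, 6)
value = torch.zeros_like(latent)
count = torch.zeros_like(latent)

# two 1 x 4 windows that overlap on columns 2-3
views = [(0, 1, 0, 4), (0, 1, 2, 6)]
window_outputs = [torch.ones(1, 1, 1, 4), 3 * torch.ones(1, 1, 1, 4)]

for (h0, h1, w0, w1), out in zip(views, window_outputs):
    value[:, :, h0:h1, w0:w1] += out
    count[:, :, h0:h1, w0:w1] += 1

merged = torch.where(count > 0, value / count, value)
print(merged)  # tensor([[[[1., 1., 2., 2., 3., 3.]]]])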
syncdiffusion/utils.py
ADDED
@@ -0,0 +1,28 @@
import numpy as np
import torch

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    panorama_height /= 8
    panorama_width /= 8
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views

def exponential_decay_list(init_weight, decay_rate, num_steps):
    weights = [init_weight * (decay_rate ** i) for i in range(num_steps)]
    return torch.tensor(weights)
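As a worked example of the tiling: get_views operates in latent space (height and width are divided by 8), so the demo's default 512 x 2048 panorama becomes a 64 x 256 latent, and with a 64 x 64 window and stride 16 it yields (64 - 64) // 16 + 1 = 1 row and (256 - 64) // 16 + 1 = 13 columns, i.e. 13 overlapping views. A quick check against the committed defaults:

# quick check of the tiling for the demo defaults (512 x 2048 panorama, stride 16)
from syncdiffusion.utils import get_views

views = get_views(512, 2048, window_size=64, stride=16)
print(len(views))   # 13 views of 64 x 64 latent windows
print(views[0])     # (0, 64, 0, 64)
print(views[-1])    # (0, 64, 192, 256)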