nikunjkdtechnoland committed on
Commit
4b98c85
1 Parent(s): e041d7d

add some more files

iopaint/file_manager/utils.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
2
+ import hashlib
3
+ from pathlib import Path
4
+
5
+ from typing import Union
6
+
7
+
8
+ def generate_filename(directory: Path, original_filename, *options) -> str:
9
+ text = str(directory.absolute()) + original_filename
10
+ for v in options:
11
+ text += "%s" % v
12
+ md5_hash = hashlib.md5()
13
+ md5_hash.update(text.encode("utf-8"))
14
+ return md5_hash.hexdigest() + ".jpg"
15
+
16
+
17
+ def parse_size(size):
18
+ if isinstance(size, int):
19
+ # If the size parameter is a single number, assume square aspect.
20
+ return [size, size]
21
+
22
+ if isinstance(size, (tuple, list)):
23
+ if len(size) == 1:
24
+ # If a single-value tuple/list is provided, expand it to two elements
25
+ return size + type(size)(size)
26
+ return size
27
+
28
+ try:
29
+ thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
30
+ except ValueError:
31
+ raise ValueError( # pylint: disable=raise-missing-from
32
+ "Bad thumbnail size format. Valid format is INTxINT."
33
+ )
34
+
35
+ if len(thumbnail_size) == 1:
36
+ # If the size parameter only contains a single integer, assume square aspect.
37
+ thumbnail_size.append(thumbnail_size[0])
38
+
39
+ return thumbnail_size
40
+
41
+
42
+ def aspect_to_string(size):
43
+ if isinstance(size, str):
44
+ return size
45
+
46
+ return "x".join(map(str, size))
47
+
48
+
49
+ IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
50
+
51
+
52
+ def glob_img(p: Union[Path, str], recursive: bool = False):
53
+ p = Path(p)
54
+ if p.is_file() and p.suffix in IMG_SUFFIX:
55
+ yield p
56
+ else:
57
+ if recursive:
58
+ files = Path(p).glob("**/*.*")
59
+ else:
60
+ files = Path(p).glob("*.*")
61
+
62
+ for it in files:
63
+ if it.suffix not in IMG_SUFFIX:
64
+ continue
65
+ yield it
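A minimal usage sketch of the helpers added above. The thumbnail directory, filename, options, and input folder are hypothetical values, not part of this commit:

from pathlib import Path
from iopaint.file_manager.utils import generate_filename, parse_size, glob_img

thumb_dir = Path("/tmp/thumbnails")                            # assumed cache directory
name = generate_filename(thumb_dir, "cat.png", "200x200", 85)  # md5 hex digest + ".jpg"
print(name)

print(parse_size(200))          # [200, 200]
print(parse_size("300x150"))    # [300, 150]

for img_path in glob_img("/tmp/photos", recursive=True):       # assumed input folder
    print(img_path)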
iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
+ from iopaint.model.anytext.ldm.util import default
8
+
9
+
10
+ class AbstractLowScaleModel(nn.Module):
11
+ # for concatenating a downsampled image to the latent representation
12
+ def __init__(self, noise_schedule_config=None):
13
+ super(AbstractLowScaleModel, self).__init__()
14
+ if noise_schedule_config is not None:
15
+ self.register_schedule(**noise_schedule_config)
16
+
17
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
+ cosine_s=cosine_s)
21
+ alphas = 1. - betas
22
+ alphas_cumprod = np.cumprod(alphas, axis=0)
23
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
+
25
+ timesteps, = betas.shape
26
+ self.num_timesteps = int(timesteps)
27
+ self.linear_start = linear_start
28
+ self.linear_end = linear_end
29
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
+
31
+ to_torch = partial(torch.tensor, dtype=torch.float32)
32
+
33
+ self.register_buffer('betas', to_torch(betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
+
44
+ def q_sample(self, x_start, t, noise=None):
45
+ noise = default(noise, lambda: torch.randn_like(x_start))
46
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48
+
49
+ def forward(self, x):
50
+ return x, None
51
+
52
+ def decode(self, x):
53
+ return x
54
+
55
+
56
+ class SimpleImageConcat(AbstractLowScaleModel):
57
+ # no noise level conditioning
58
+ def __init__(self):
59
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60
+ self.max_noise_level = 0
61
+
62
+ def forward(self, x):
63
+ # fix to constant noise level
64
+ return x, torch.zeros(x.shape[0], device=x.device).long()
65
+
66
+
67
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69
+ super().__init__(noise_schedule_config=noise_schedule_config)
70
+ self.max_noise_level = max_noise_level
71
+
72
+ def forward(self, x, noise_level=None):
73
+ if noise_level is None:
74
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75
+ else:
76
+ assert isinstance(noise_level, torch.Tensor)
77
+ z = self.q_sample(x, noise_level)
78
+ return z, noise_level
79
+
80
+
81
+
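A hedged sketch of exercising the noise-augmentation class above. The schedule values, tensor shapes, and max_noise_level are illustrative, not the project's actual configuration:

import torch
from iopaint.model.anytext.ldm.modules.diffusionmodules.upscaling import (
    ImageConcatWithNoiseAugmentation,
)

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"timesteps": 1000, "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,
)
x_low = torch.randn(2, 4, 32, 32)    # stand-in for a downsampled conditioning latent
z, noise_level = aug(x_low)          # q_sample at a random noise level per sample
print(z.shape, noise_level)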
iopaint/model/anytext/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,271 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from iopaint.model.anytext.ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32)
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+ args = tuple(inputs) + tuple(params)
114
+ return CheckpointFunction.apply(func, len(inputs), *args)
115
+ else:
116
+ return func(*inputs)
117
+
118
+
119
+ class CheckpointFunction(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(ctx, run_function, length, *args):
122
+ ctx.run_function = run_function
123
+ ctx.input_tensors = list(args[:length])
124
+ ctx.input_params = list(args[length:])
125
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
126
+ "dtype": torch.get_autocast_gpu_dtype(),
127
+ "cache_enabled": torch.is_autocast_cache_enabled()}
128
+ with torch.no_grad():
129
+ output_tensors = ctx.run_function(*ctx.input_tensors)
130
+ return output_tensors
131
+
132
+ @staticmethod
133
+ def backward(ctx, *output_grads):
134
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
135
+ with torch.enable_grad(), \
136
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
137
+ # Fixes a bug where the first op in run_function modifies the
138
+ # Tensor storage in place, which is not allowed for detach()'d
139
+ # Tensors.
140
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
141
+ output_tensors = ctx.run_function(*shallow_copies)
142
+ input_grads = torch.autograd.grad(
143
+ output_tensors,
144
+ ctx.input_tensors + ctx.input_params,
145
+ output_grads,
146
+ allow_unused=True,
147
+ )
148
+ del ctx.input_tensors
149
+ del ctx.input_params
150
+ del output_tensors
151
+ return (None, None) + input_grads
152
+
153
+
154
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
155
+ """
156
+ Create sinusoidal timestep embeddings.
157
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
158
+ These may be fractional.
159
+ :param dim: the dimension of the output.
160
+ :param max_period: controls the minimum frequency of the embeddings.
161
+ :return: an [N x dim] Tensor of positional embeddings.
162
+ """
163
+ if not repeat_only:
164
+ half = dim // 2
165
+ freqs = torch.exp(
166
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
167
+ ).to(device=timesteps.device)
168
+ args = timesteps[:, None].float() * freqs[None]
169
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
170
+ if dim % 2:
171
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
172
+ else:
173
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
174
+ return embedding
175
+
176
+
177
+ def zero_module(module):
178
+ """
179
+ Zero out the parameters of a module and return it.
180
+ """
181
+ for p in module.parameters():
182
+ p.detach().zero_()
183
+ return module
184
+
185
+
186
+ def scale_module(module, scale):
187
+ """
188
+ Scale the parameters of a module and return it.
189
+ """
190
+ for p in module.parameters():
191
+ p.detach().mul_(scale)
192
+ return module
193
+
194
+
195
+ def mean_flat(tensor):
196
+ """
197
+ Take the mean over all non-batch dimensions.
198
+ """
199
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
200
+
201
+
202
+ def normalization(channels):
203
+ """
204
+ Make a standard normalization layer.
205
+ :param channels: number of input channels.
206
+ :return: an nn.Module for normalization.
207
+ """
208
+ return GroupNorm32(32, channels)
209
+
210
+
211
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
212
+ class SiLU(nn.Module):
213
+ def forward(self, x):
214
+ return x * torch.sigmoid(x)
215
+
216
+
217
+ class GroupNorm32(nn.GroupNorm):
218
+ def forward(self, x):
219
+ # return super().forward(x.float()).type(x.dtype)
220
+ return super().forward(x).type(x.dtype)
221
+
222
+ def conv_nd(dims, *args, **kwargs):
223
+ """
224
+ Create a 1D, 2D, or 3D convolution module.
225
+ """
226
+ if dims == 1:
227
+ return nn.Conv1d(*args, **kwargs)
228
+ elif dims == 2:
229
+ return nn.Conv2d(*args, **kwargs)
230
+ elif dims == 3:
231
+ return nn.Conv3d(*args, **kwargs)
232
+ raise ValueError(f"unsupported dimensions: {dims}")
233
+
234
+
235
+ def linear(*args, **kwargs):
236
+ """
237
+ Create a linear module.
238
+ """
239
+ return nn.Linear(*args, **kwargs)
240
+
241
+
242
+ def avg_pool_nd(dims, *args, **kwargs):
243
+ """
244
+ Create a 1D, 2D, or 3D average pooling module.
245
+ """
246
+ if dims == 1:
247
+ return nn.AvgPool1d(*args, **kwargs)
248
+ elif dims == 2:
249
+ return nn.AvgPool2d(*args, **kwargs)
250
+ elif dims == 3:
251
+ return nn.AvgPool3d(*args, **kwargs)
252
+ raise ValueError(f"unsupported dimensions: {dims}")
253
+
254
+
255
+ class HybridConditioner(nn.Module):
256
+
257
+ def __init__(self, c_concat_config, c_crossattn_config):
258
+ super().__init__()
259
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
260
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
261
+
262
+ def forward(self, c_concat, c_crossattn):
263
+ c_concat = self.concat_conditioner(c_concat)
264
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
265
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
266
+
267
+
268
+ def noise_like(shape, device, repeat=False):
269
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
270
+ noise = lambda: torch.randn(shape, device=device)
271
+ return repeat_noise() if repeat else noise()
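Two of the helpers above in isolation, as a rough sketch; the tiny Conv2d module and tensor shapes are made up for illustration:

import torch
import torch.nn as nn
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    timestep_embedding, checkpoint, zero_module,
)

t = torch.randint(0, 1000, (4,))                 # one timestep index per batch element
emb = timestep_embedding(t, dim=320)             # [4, 320] sinusoidal embedding
print(emb.shape)

layer = zero_module(nn.Conv2d(4, 4, 3, padding=1))   # conv whose parameters start at zero
x = torch.randn(4, 4, 8, 8, requires_grad=True)
y = checkpoint(layer, (x,), list(layer.parameters()), flag=True)  # recompute activations in backward
y.sum().backward()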
iopaint/model/anytext/ldm/util.py ADDED
@@ -0,0 +1,197 @@
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+
7
+ from inspect import isfunction
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ def log_txt_as_img(wh, xc, size=10):
12
+ # wh a tuple of (width, height)
13
+ # xc a list of captions to plot
14
+ b = len(xc)
15
+ txts = list()
16
+ for bi in range(b):
17
+ txt = Image.new("RGB", wh, color="white")
18
+ draw = ImageDraw.Draw(txt)
19
+ font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
20
+ nc = int(32 * (wh[0] / 256))
21
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22
+
23
+ try:
24
+ draw.text((0, 0), lines, fill="black", font=font)
25
+ except UnicodeEncodeError:
26
+ print("Cant encode string for logging. Skipping.")
27
+
28
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29
+ txts.append(txt)
30
+ txts = np.stack(txts)
31
+ txts = torch.tensor(txts)
32
+ return txts
33
+
34
+
35
+ def ismap(x):
36
+ if not isinstance(x, torch.Tensor):
37
+ return False
38
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
39
+
40
+
41
+ def isimage(x):
42
+ if not isinstance(x, torch.Tensor):
43
+ return False
44
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45
+
46
+
47
+ def exists(x):
48
+ return x is not None
49
+
50
+
51
+ def default(val, d):
52
+ if exists(val):
53
+ return val
54
+ return d() if isfunction(d) else d
55
+
56
+
57
+ def mean_flat(tensor):
58
+ """
59
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60
+ Take the mean over all non-batch dimensions.
61
+ """
62
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
63
+
64
+
65
+ def count_params(model, verbose=False):
66
+ total_params = sum(p.numel() for p in model.parameters())
67
+ if verbose:
68
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69
+ return total_params
70
+
71
+
72
+ def instantiate_from_config(config, **kwargs):
73
+ if "target" not in config:
74
+ if config == '__is_first_stage__':
75
+ return None
76
+ elif config == "__is_unconditional__":
77
+ return None
78
+ raise KeyError("Expected key `target` to instantiate.")
79
+ return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
80
+
81
+
82
+ def get_obj_from_str(string, reload=False):
83
+ module, cls = string.rsplit(".", 1)
84
+ if reload:
85
+ module_imp = importlib.import_module(module)
86
+ importlib.reload(module_imp)
87
+ return getattr(importlib.import_module(module, package=None), cls)
88
+
89
+
90
+ class AdamWwithEMAandWings(optim.Optimizer):
91
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94
+ ema_power=1., param_names=()):
95
+ """AdamW that saves EMA versions of the parameters."""
96
+ if not 0.0 <= lr:
97
+ raise ValueError("Invalid learning rate: {}".format(lr))
98
+ if not 0.0 <= eps:
99
+ raise ValueError("Invalid epsilon value: {}".format(eps))
100
+ if not 0.0 <= betas[0] < 1.0:
101
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102
+ if not 0.0 <= betas[1] < 1.0:
103
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104
+ if not 0.0 <= weight_decay:
105
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106
+ if not 0.0 <= ema_decay <= 1.0:
107
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108
+ defaults = dict(lr=lr, betas=betas, eps=eps,
109
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110
+ ema_power=ema_power, param_names=param_names)
111
+ super().__init__(params, defaults)
112
+
113
+ def __setstate__(self, state):
114
+ super().__setstate__(state)
115
+ for group in self.param_groups:
116
+ group.setdefault('amsgrad', False)
117
+
118
+ @torch.no_grad()
119
+ def step(self, closure=None):
120
+ """Performs a single optimization step.
121
+ Args:
122
+ closure (callable, optional): A closure that reevaluates the model
123
+ and returns the loss.
124
+ """
125
+ loss = None
126
+ if closure is not None:
127
+ with torch.enable_grad():
128
+ loss = closure()
129
+
130
+ for group in self.param_groups:
131
+ params_with_grad = []
132
+ grads = []
133
+ exp_avgs = []
134
+ exp_avg_sqs = []
135
+ ema_params_with_grad = []
136
+ state_sums = []
137
+ max_exp_avg_sqs = []
138
+ state_steps = []
139
+ amsgrad = group['amsgrad']
140
+ beta1, beta2 = group['betas']
141
+ ema_decay = group['ema_decay']
142
+ ema_power = group['ema_power']
143
+
144
+ for p in group['params']:
145
+ if p.grad is None:
146
+ continue
147
+ params_with_grad.append(p)
148
+ if p.grad.is_sparse:
149
+ raise RuntimeError('AdamW does not support sparse gradients')
150
+ grads.append(p.grad)
151
+
152
+ state = self.state[p]
153
+
154
+ # State initialization
155
+ if len(state) == 0:
156
+ state['step'] = 0
157
+ # Exponential moving average of gradient values
158
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159
+ # Exponential moving average of squared gradient values
160
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161
+ if amsgrad:
162
+ # Maintains max of all exp. moving avg. of sq. grad. values
163
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164
+ # Exponential moving average of parameter values
165
+ state['param_exp_avg'] = p.detach().float().clone()
166
+
167
+ exp_avgs.append(state['exp_avg'])
168
+ exp_avg_sqs.append(state['exp_avg_sq'])
169
+ ema_params_with_grad.append(state['param_exp_avg'])
170
+
171
+ if amsgrad:
172
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173
+
174
+ # update the steps for each param group update
175
+ state['step'] += 1
176
+ # record the step after step update
177
+ state_steps.append(state['step'])
178
+
179
+ optim._functional.adamw(params_with_grad,
180
+ grads,
181
+ exp_avgs,
182
+ exp_avg_sqs,
183
+ max_exp_avg_sqs,
184
+ state_steps,
185
+ amsgrad=amsgrad,
186
+ beta1=beta1,
187
+ beta2=beta2,
188
+ lr=group['lr'],
189
+ weight_decay=group['weight_decay'],
190
+ eps=group['eps'],
191
+ maximize=False)
192
+
193
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196
+
197
+ return loss
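A small sketch of the config-driven instantiation helpers above; the config dict is made up and simply targets a stock torch class:

from iopaint.model.anytext.ldm.util import instantiate_from_config, get_obj_from_str

cfg = {
    "target": "torch.nn.Linear",                 # any importable dotted path
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(cfg)             # equivalent to torch.nn.Linear(8, 4)
print(type(layer))

cls = get_obj_from_str("torch.nn.GELU")          # resolve a class without instantiating it
print(cls)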
iopaint/model/anytext/utils.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import datetime
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw
6
+
7
+
8
+ def save_images(img_list, folder):
9
+ if not os.path.exists(folder):
10
+ os.makedirs(folder)
11
+ now = datetime.datetime.now()
12
+ date_str = now.strftime("%Y-%m-%d")
13
+ folder_path = os.path.join(folder, date_str)
14
+ if not os.path.exists(folder_path):
15
+ os.makedirs(folder_path)
16
+ time_str = now.strftime("%H_%M_%S")
17
+ for idx, img in enumerate(img_list):
18
+ image_number = idx + 1
19
+ filename = f"{time_str}_{image_number}.jpg"
20
+ save_path = os.path.join(folder_path, filename)
21
+ cv2.imwrite(save_path, img[..., ::-1])
22
+
23
+
24
+ def check_channels(image):
25
+ channels = image.shape[2] if len(image.shape) == 3 else 1
26
+ if channels == 1:
27
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
28
+ elif channels > 3:
29
+ image = image[:, :, :3]
30
+ return image
31
+
32
+
33
+ def resize_image(img, max_length=768):
34
+ height, width = img.shape[:2]
35
+ max_dimension = max(height, width)
36
+
37
+ if max_dimension > max_length:
38
+ scale_factor = max_length / max_dimension
39
+ new_width = int(round(width * scale_factor))
40
+ new_height = int(round(height * scale_factor))
41
+ new_size = (new_width, new_height)
42
+ img = cv2.resize(img, new_size)
43
+ height, width = img.shape[:2]
44
+ img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
45
+ return img
46
+
47
+
48
+ def insert_spaces(string, nSpace):
49
+ if nSpace == 0:
50
+ return string
51
+ new_string = ""
52
+ for char in string:
53
+ new_string += char + " " * nSpace
54
+ return new_string[:-nSpace]
55
+
56
+
57
+ def draw_glyph(font, text):
58
+ g_size = 50
59
+ W, H = (512, 80)
60
+ new_font = font.font_variant(size=g_size)
61
+ img = Image.new(mode="1", size=(W, H), color=0)
62
+ draw = ImageDraw.Draw(img)
63
+ left, top, right, bottom = new_font.getbbox(text)
64
+ text_width = max(right - left, 5)
65
+ text_height = max(bottom - top, 5)
66
+ ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
67
+ new_font = font.font_variant(size=int(g_size * ratio))
68
+
69
+ text_width, text_height = new_font.getsize(text)
70
+ offset_x, offset_y = new_font.getoffset(text)
71
+ x = (img.width - text_width) // 2
72
+ y = (img.height - text_height) // 2 - offset_y // 2
73
+ draw.text((x, y), text, font=new_font, fill="white")
74
+ img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
75
+ return img
76
+
77
+
78
+ def draw_glyph2(
79
+ font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True
80
+ ):
81
+ enlarge_polygon = polygon * scale
82
+ rect = cv2.minAreaRect(enlarge_polygon)
83
+ box = cv2.boxPoints(rect)
84
+ box = np.int0(box)
85
+ w, h = rect[1]
86
+ angle = rect[2]
87
+ if angle < -45:
88
+ angle += 90
89
+ angle = -angle
90
+ if w < h:
91
+ angle += 90
92
+
93
+ vert = False
94
+ if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
95
+ _w = max(box[:, 0]) - min(box[:, 0])
96
+ _h = max(box[:, 1]) - min(box[:, 1])
97
+ if _h >= _w:
98
+ vert = True
99
+ angle = 0
100
+
101
+ img = np.zeros((height * scale, width * scale, 3), np.uint8)
102
+ img = Image.fromarray(img)
103
+
104
+ # infer font size
105
+ image4ratio = Image.new("RGB", img.size, "white")
106
+ draw = ImageDraw.Draw(image4ratio)
107
+ _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
108
+ text_w = min(w, h) * (_tw / _th)
109
+ if text_w <= max(w, h):
110
+ # add space
111
+ if len(text) > 1 and not vert and add_space:
112
+ for i in range(1, 100):
113
+ text_space = insert_spaces(text, i)
114
+ _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
115
+ if min(w, h) * (_tw2 / _th2) > max(w, h):
116
+ break
117
+ text = insert_spaces(text, i - 1)
118
+ font_size = min(w, h) * 0.80
119
+ else:
120
+ shrink = 0.75 if vert else 0.85
121
+ font_size = min(w, h) / (text_w / max(w, h)) * shrink
122
+ new_font = font.font_variant(size=int(font_size))
123
+
124
+ left, top, right, bottom = new_font.getbbox(text)
125
+ text_width = right - left
126
+ text_height = bottom - top
127
+
128
+ layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
129
+ draw = ImageDraw.Draw(layer)
130
+ if not vert:
131
+ draw.text(
132
+ (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
133
+ text,
134
+ font=new_font,
135
+ fill=(255, 255, 255, 255),
136
+ )
137
+ else:
138
+ x_s = min(box[:, 0]) + _w // 2 - text_height // 2
139
+ y_s = min(box[:, 1])
140
+ for c in text:
141
+ draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
142
+ _, _t, _, _b = new_font.getbbox(c)
143
+ y_s += _b
144
+
145
+ rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
146
+
147
+ x_offset = int((img.width - rotated_layer.width) / 2)
148
+ y_offset = int((img.height - rotated_layer.height) / 2)
149
+ img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
150
+ img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
151
+ return img
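A hypothetical round trip through the image helpers above; the input array is synthetic:

import numpy as np
from iopaint.model.anytext.utils import check_channels, resize_image, insert_spaces

img = (np.random.rand(1200, 900) * 255).astype(np.uint8)   # fake single-channel image
img = check_channels(img)             # converted to H x W x 3
img = resize_image(img, max_length=768)
print(img.shape)                      # longest side <= 768, both sides divisible by 64
print(insert_spaces("anytext", 2))    # "a  n  y  t  e  x  t"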
iopaint/model/original_sd_configs/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
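The YAML above follows the standard ldm config layout. A hedged sketch of reading it; the use of OmegaConf is an assumption about how such configs are typically consumed, and actually instantiating the model would additionally require the targeted ldm modules to be importable:

from omegaconf import OmegaConf

config = OmegaConf.load("iopaint/model/original_sd_configs/v1-inference.yaml")
print(config.model.target)                                  # ldm.models.diffusion.ddpm.LatentDiffusion
print(config.model.params.unet_config.params.context_dim)   # 768
# instantiate_from_config(config.model) would build the model from this config tree.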
iopaint/model/original_sd_configs/v2-inference-v.yaml ADDED
@@ -0,0 +1,68 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ parameterization: "v"
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False # set to false because this is an inference-only config
20
+
21
+ unet_config:
22
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23
+ params:
24
+ use_checkpoint: True
25
+ use_fp16: True
26
+ image_size: 32 # unused
27
+ in_channels: 4
28
+ out_channels: 4
29
+ model_channels: 320
30
+ attention_resolutions: [ 4, 2, 1 ]
31
+ num_res_blocks: 2
32
+ channel_mult: [ 1, 2, 4, 4 ]
33
+ num_head_channels: 64 # need to fix for flash-attn
34
+ use_spatial_transformer: True
35
+ use_linear_in_transformer: True
36
+ transformer_depth: 1
37
+ context_dim: 1024
38
+ legacy: False
39
+
40
+ first_stage_config:
41
+ target: ldm.models.autoencoder.AutoencoderKL
42
+ params:
43
+ embed_dim: 4
44
+ monitor: val/rec_loss
45
+ ddconfig:
46
+ #attn_type: "vanilla-xformers"
47
+ double_z: true
48
+ z_channels: 4
49
+ resolution: 256
50
+ in_channels: 3
51
+ out_ch: 3
52
+ ch: 128
53
+ ch_mult:
54
+ - 1
55
+ - 2
56
+ - 4
57
+ - 4
58
+ num_res_blocks: 2
59
+ attn_resolutions: []
60
+ dropout: 0.0
61
+ lossconfig:
62
+ target: torch.nn.Identity
63
+
64
+ cond_stage_config:
65
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
66
+ params:
67
+ freeze: True
68
+ layer: "penultimate"
iopaint/model/utils.py ADDED
@@ -0,0 +1,1033 @@
1
+ import gc
2
+ import math
3
+ import random
4
+ import traceback
5
+ from typing import Any
6
+
7
+ import torch
8
+ import numpy as np
9
+ import collections
10
+ from itertools import repeat
11
+
12
+ from diffusers import (
13
+ DDIMScheduler,
14
+ PNDMScheduler,
15
+ LMSDiscreteScheduler,
16
+ EulerDiscreteScheduler,
17
+ EulerAncestralDiscreteScheduler,
18
+ DPMSolverMultistepScheduler,
19
+ UniPCMultistepScheduler,
20
+ LCMScheduler,
21
+ DPMSolverSinglestepScheduler,
22
+ KDPM2DiscreteScheduler,
23
+ KDPM2AncestralDiscreteScheduler,
24
+ HeunDiscreteScheduler,
25
+ )
26
+ from loguru import logger
27
+
28
+ from iopaint.schema import SDSampler
29
+ from torch import conv2d, conv_transpose2d
30
+
31
+
32
+ def make_beta_schedule(
33
+ device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
34
+ ):
35
+ if schedule == "linear":
36
+ betas = (
37
+ torch.linspace(
38
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
39
+ )
40
+ ** 2
41
+ )
42
+
43
+ elif schedule == "cosine":
44
+ timesteps = (
45
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
46
+ ).to(device)
47
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
48
+ alphas = torch.cos(alphas).pow(2).to(device)
49
+ alphas = alphas / alphas[0]
50
+ betas = 1 - alphas[1:] / alphas[:-1]
51
+ betas = np.clip(betas, a_min=0, a_max=0.999)
52
+
53
+ elif schedule == "sqrt_linear":
54
+ betas = torch.linspace(
55
+ linear_start, linear_end, n_timestep, dtype=torch.float64
56
+ )
57
+ elif schedule == "sqrt":
58
+ betas = (
59
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
60
+ ** 0.5
61
+ )
62
+ else:
63
+ raise ValueError(f"schedule '{schedule}' unknown.")
64
+ return betas.numpy()
65
+
66
+
67
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
68
+ # select alphas for computing the variance schedule
69
+ alphas = alphacums[ddim_timesteps]
70
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
71
+
72
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
73
+ sigmas = eta * np.sqrt(
74
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
75
+ )
76
+ if verbose:
77
+ print(
78
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
79
+ )
80
+ print(
81
+ f"For the chosen value of eta, which is {eta}, "
82
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
83
+ )
84
+ return sigmas, alphas, alphas_prev
85
+
86
+
87
+ def make_ddim_timesteps(
88
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
89
+ ):
90
+ if ddim_discr_method == "uniform":
91
+ c = num_ddpm_timesteps // num_ddim_timesteps
92
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
93
+ elif ddim_discr_method == "quad":
94
+ ddim_timesteps = (
95
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
96
+ ).astype(int)
97
+ else:
98
+ raise NotImplementedError(
99
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
100
+ )
101
+
102
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
103
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
104
+ steps_out = ddim_timesteps + 1
105
+ if verbose:
106
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
107
+ return steps_out
108
+
109
+
110
+ def noise_like(shape, device, repeat=False):
111
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
112
+ shape[0], *((1,) * (len(shape) - 1))
113
+ )
114
+ noise = lambda: torch.randn(shape, device=device)
115
+ return repeat_noise() if repeat else noise()
116
+
117
+
118
+ def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
119
+ """
120
+ Create sinusoidal timestep embeddings.
121
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
122
+ These may be fractional.
123
+ :param dim: the dimension of the output.
124
+ :param max_period: controls the minimum frequency of the embeddings.
125
+ :return: an [N x dim] Tensor of positional embeddings.
126
+ """
127
+ half = dim // 2
128
+ freqs = torch.exp(
129
+ -math.log(max_period)
130
+ * torch.arange(start=0, end=half, dtype=torch.float32)
131
+ / half
132
+ ).to(device=device)
133
+
134
+ args = timesteps[:, None].float() * freqs[None]
135
+
136
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
137
+ if dim % 2:
138
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
139
+ return embedding
140
+
141
+
142
+ ###### MAT and FcF #######
143
+
144
+
145
+ def normalize_2nd_moment(x, dim=1):
146
+ return (
147
+ x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt()
148
+ )
149
+
150
+
151
+ class EasyDict(dict):
152
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
153
+
154
+ def __getattr__(self, name: str) -> Any:
155
+ try:
156
+ return self[name]
157
+ except KeyError:
158
+ raise AttributeError(name)
159
+
160
+ def __setattr__(self, name: str, value: Any) -> None:
161
+ self[name] = value
162
+
163
+ def __delattr__(self, name: str) -> None:
164
+ del self[name]
165
+
166
+
167
+ def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
168
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
169
+ assert isinstance(x, torch.Tensor)
170
+ assert clamp is None or clamp >= 0
171
+ spec = activation_funcs[act]
172
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
173
+ gain = float(gain if gain is not None else spec.def_gain)
174
+ clamp = float(clamp if clamp is not None else -1)
175
+
176
+ # Add bias.
177
+ if b is not None:
178
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
179
+ assert 0 <= dim < x.ndim
180
+ assert b.shape[0] == x.shape[dim]
181
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
182
+
183
+ # Evaluate activation function.
184
+ alpha = float(alpha)
185
+ x = spec.func(x, alpha=alpha)
186
+
187
+ # Scale by gain.
188
+ gain = float(gain)
189
+ if gain != 1:
190
+ x = x * gain
191
+
192
+ # Clamp.
193
+ if clamp >= 0:
194
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
195
+ return x
196
+
197
+
198
+ def bias_act(
199
+ x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
200
+ ):
201
+ r"""Fused bias and activation function.
202
+
203
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
204
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
205
+ the fused op is considerably more efficient than performing the same calculation
206
+ using standard PyTorch ops. It supports first and second order gradients,
207
+ but not third order gradients.
208
+
209
+ Args:
210
+ x: Input activation tensor. Can be of any shape.
211
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
212
+ as `x`. The shape must be known, and it must match the dimension of `x`
213
+ corresponding to `dim`.
214
+ dim: The dimension in `x` corresponding to the elements of `b`.
215
+ The value of `dim` is ignored if `b` is not specified.
216
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
217
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
218
+ See `activation_funcs` for a full list. `None` is not allowed.
219
+ alpha: Shape parameter for the activation function, or `None` to use the default.
220
+ gain: Scaling factor for the output tensor, or `None` to use default.
221
+ See `activation_funcs` for the default scaling of each activation function.
222
+ If unsure, consider specifying 1.
223
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
224
+ the clamping (default).
225
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
226
+
227
+ Returns:
228
+ Tensor of the same shape and datatype as `x`.
229
+ """
230
+ assert isinstance(x, torch.Tensor)
231
+ assert impl in ["ref", "cuda"]
232
+ return _bias_act_ref(
233
+ x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
234
+ )
235
+
236
+
237
+ def _get_filter_size(f):
238
+ if f is None:
239
+ return 1, 1
240
+
241
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
242
+ fw = f.shape[-1]
243
+ fh = f.shape[0]
244
+
245
+ fw = int(fw)
246
+ fh = int(fh)
247
+ assert fw >= 1 and fh >= 1
248
+ return fw, fh
249
+
250
+
251
+ def _get_weight_shape(w):
252
+ shape = [int(sz) for sz in w.shape]
253
+ return shape
254
+
255
+
256
+ def _parse_scaling(scaling):
257
+ if isinstance(scaling, int):
258
+ scaling = [scaling, scaling]
259
+ assert isinstance(scaling, (list, tuple))
260
+ assert all(isinstance(x, int) for x in scaling)
261
+ sx, sy = scaling
262
+ assert sx >= 1 and sy >= 1
263
+ return sx, sy
264
+
265
+
266
+ def _parse_padding(padding):
267
+ if isinstance(padding, int):
268
+ padding = [padding, padding]
269
+ assert isinstance(padding, (list, tuple))
270
+ assert all(isinstance(x, int) for x in padding)
271
+ if len(padding) == 2:
272
+ padx, pady = padding
273
+ padding = [padx, padx, pady, pady]
274
+ padx0, padx1, pady0, pady1 = padding
275
+ return padx0, padx1, pady0, pady1
276
+
277
+
278
+ def setup_filter(
279
+ f,
280
+ device=torch.device("cpu"),
281
+ normalize=True,
282
+ flip_filter=False,
283
+ gain=1,
284
+ separable=None,
285
+ ):
286
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
287
+
288
+ Args:
289
+ f: Torch tensor, numpy array, or python list of the shape
290
+ `[filter_height, filter_width]` (non-separable),
291
+ `[filter_taps]` (separable),
292
+ `[]` (impulse), or
293
+ `None` (identity).
294
+ device: Result device (default: cpu).
295
+ normalize: Normalize the filter so that it retains the magnitude
296
+ for constant input signal (DC)? (default: True).
297
+ flip_filter: Flip the filter? (default: False).
298
+ gain: Overall scaling factor for signal magnitude (default: 1).
299
+ separable: Return a separable filter? (default: select automatically).
300
+
301
+ Returns:
302
+ Float32 tensor of the shape
303
+ `[filter_height, filter_width]` (non-separable) or
304
+ `[filter_taps]` (separable).
305
+ """
306
+ # Validate.
307
+ if f is None:
308
+ f = 1
309
+ f = torch.as_tensor(f, dtype=torch.float32)
310
+ assert f.ndim in [0, 1, 2]
311
+ assert f.numel() > 0
312
+ if f.ndim == 0:
313
+ f = f[np.newaxis]
314
+
315
+ # Separable?
316
+ if separable is None:
317
+ separable = f.ndim == 1 and f.numel() >= 8
318
+ if f.ndim == 1 and not separable:
319
+ f = f.ger(f)
320
+ assert f.ndim == (1 if separable else 2)
321
+
322
+ # Apply normalize, flip, gain, and device.
323
+ if normalize:
324
+ f /= f.sum()
325
+ if flip_filter:
326
+ f = f.flip(list(range(f.ndim)))
327
+ f = f * (gain ** (f.ndim / 2))
328
+ f = f.to(device=device)
329
+ return f
330
+
331
+
332
+ def _ntuple(n):
333
+ def parse(x):
334
+ if isinstance(x, collections.abc.Iterable):
335
+ return x
336
+ return tuple(repeat(x, n))
337
+
338
+ return parse
339
+
340
+
341
+ to_2tuple = _ntuple(2)
342
+
343
+ activation_funcs = {
344
+ "linear": EasyDict(
345
+ func=lambda x, **_: x,
346
+ def_alpha=0,
347
+ def_gain=1,
348
+ cuda_idx=1,
349
+ ref="",
350
+ has_2nd_grad=False,
351
+ ),
352
+ "relu": EasyDict(
353
+ func=lambda x, **_: torch.nn.functional.relu(x),
354
+ def_alpha=0,
355
+ def_gain=np.sqrt(2),
356
+ cuda_idx=2,
357
+ ref="y",
358
+ has_2nd_grad=False,
359
+ ),
360
+ "lrelu": EasyDict(
361
+ func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
362
+ def_alpha=0.2,
363
+ def_gain=np.sqrt(2),
364
+ cuda_idx=3,
365
+ ref="y",
366
+ has_2nd_grad=False,
367
+ ),
368
+ "tanh": EasyDict(
369
+ func=lambda x, **_: torch.tanh(x),
370
+ def_alpha=0,
371
+ def_gain=1,
372
+ cuda_idx=4,
373
+ ref="y",
374
+ has_2nd_grad=True,
375
+ ),
376
+ "sigmoid": EasyDict(
377
+ func=lambda x, **_: torch.sigmoid(x),
378
+ def_alpha=0,
379
+ def_gain=1,
380
+ cuda_idx=5,
381
+ ref="y",
382
+ has_2nd_grad=True,
383
+ ),
384
+ "elu": EasyDict(
385
+ func=lambda x, **_: torch.nn.functional.elu(x),
386
+ def_alpha=0,
387
+ def_gain=1,
388
+ cuda_idx=6,
389
+ ref="y",
390
+ has_2nd_grad=True,
391
+ ),
392
+ "selu": EasyDict(
393
+ func=lambda x, **_: torch.nn.functional.selu(x),
394
+ def_alpha=0,
395
+ def_gain=1,
396
+ cuda_idx=7,
397
+ ref="y",
398
+ has_2nd_grad=True,
399
+ ),
400
+ "softplus": EasyDict(
401
+ func=lambda x, **_: torch.nn.functional.softplus(x),
402
+ def_alpha=0,
403
+ def_gain=1,
404
+ cuda_idx=8,
405
+ ref="y",
406
+ has_2nd_grad=True,
407
+ ),
408
+ "swish": EasyDict(
409
+ func=lambda x, **_: torch.sigmoid(x) * x,
410
+ def_alpha=0,
411
+ def_gain=np.sqrt(2),
412
+ cuda_idx=9,
413
+ ref="x",
414
+ has_2nd_grad=True,
415
+ ),
416
+ }
417
+
418
+
419
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
420
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
421
+
422
+ Performs the following sequence of operations for each channel:
423
+
424
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
425
+
426
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
427
+ Negative padding corresponds to cropping the image.
428
+
429
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
430
+ so that the footprint of all output pixels lies within the input image.
431
+
432
+ 4. Downsample the image by keeping every Nth pixel (`down`).
433
+
434
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
435
+ The fused op is considerably more efficient than performing the same calculation
436
+ using standard PyTorch ops. It supports gradients of arbitrary order.
437
+
438
+ Args:
439
+ x: Float32/float64/float16 input tensor of the shape
440
+ `[batch_size, num_channels, in_height, in_width]`.
441
+ f: Float32 FIR filter of the shape
442
+ `[filter_height, filter_width]` (non-separable),
443
+ `[filter_taps]` (separable), or
444
+ `None` (identity).
445
+ up: Integer upsampling factor. Can be a single int or a list/tuple
446
+ `[x, y]` (default: 1).
447
+ down: Integer downsampling factor. Can be a single int or a list/tuple
448
+ `[x, y]` (default: 1).
449
+ padding: Padding with respect to the upsampled image. Can be a single number
450
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
451
+ (default: 0).
452
+ flip_filter: False = convolution, True = correlation (default: False).
453
+ gain: Overall scaling factor for signal magnitude (default: 1).
454
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
455
+
456
+ Returns:
457
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
458
+ """
459
+ # assert isinstance(x, torch.Tensor)
460
+ # assert impl in ['ref', 'cuda']
461
+ return _upfirdn2d_ref(
462
+ x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
463
+ )
464
+
465
+
466
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
467
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
468
+ # Validate arguments.
469
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
470
+ if f is None:
471
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
472
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
473
+ assert not f.requires_grad
474
+ batch_size, num_channels, in_height, in_width = x.shape
475
+ # upx, upy = _parse_scaling(up)
476
+ # downx, downy = _parse_scaling(down)
477
+
478
+ upx, upy = up, up
479
+ downx, downy = down, down
480
+
481
+ # padx0, padx1, pady0, pady1 = _parse_padding(padding)
482
+ padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
483
+
484
+ # Upsample by inserting zeros.
485
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
486
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
487
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
488
+
489
+ # Pad or crop.
490
+ x = torch.nn.functional.pad(
491
+ x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
492
+ )
493
+ x = x[
494
+ :,
495
+ :,
496
+ max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
497
+ max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
498
+ ]
499
+
500
+ # Setup filter.
501
+ f = f * (gain ** (f.ndim / 2))
502
+ f = f.to(x.dtype)
503
+ if not flip_filter:
504
+ f = f.flip(list(range(f.ndim)))
505
+
506
+ # Convolve with the filter.
507
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
508
+ if f.ndim == 4:
509
+ x = conv2d(input=x, weight=f, groups=num_channels)
510
+ else:
511
+ x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
512
+ x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
513
+
514
+ # Downsample by throwing away pixels.
515
+ x = x[:, :, ::downy, ::downx]
516
+ return x
517
+
518
+
519
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
520
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
521
+
522
+ By default, the result is padded so that its shape is a fraction of the input.
523
+ User-specified padding is applied on top of that, with negative values
524
+ indicating cropping. Pixels outside the image are assumed to be zero.
525
+
526
+ Args:
527
+ x: Float32/float64/float16 input tensor of the shape
528
+ `[batch_size, num_channels, in_height, in_width]`.
529
+ f: Float32 FIR filter of the shape
530
+ `[filter_height, filter_width]` (non-separable),
531
+ `[filter_taps]` (separable), or
532
+ `None` (identity).
533
+ down: Integer downsampling factor. Can be a single int or a list/tuple
534
+ `[x, y]` (default: 1).
535
+ padding: Padding with respect to the input. Can be a single number or a
536
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
537
+ (default: 0).
538
+ flip_filter: False = convolution, True = correlation (default: False).
539
+ gain: Overall scaling factor for signal magnitude (default: 1).
540
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
541
+
542
+ Returns:
543
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
544
+ """
545
+ downx, downy = _parse_scaling(down)
546
+ # padx0, padx1, pady0, pady1 = _parse_padding(padding)
547
+ padx0, padx1, pady0, pady1 = padding, padding, padding, padding
548
+
549
+ fw, fh = _get_filter_size(f)
550
+ p = [
551
+ padx0 + (fw - downx + 1) // 2,
552
+ padx1 + (fw - downx) // 2,
553
+ pady0 + (fh - downy + 1) // 2,
554
+ pady1 + (fh - downy) // 2,
555
+ ]
556
+ return upfirdn2d(
557
+ x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
558
+ )
559
+
560
+
561
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
562
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
563
+
564
+ By default, the result is padded so that its shape is a multiple of the input.
565
+ User-specified padding is applied on top of that, with negative values
566
+ indicating cropping. Pixels outside the image are assumed to be zero.
567
+
568
+ Args:
569
+ x: Float32/float64/float16 input tensor of the shape
570
+ `[batch_size, num_channels, in_height, in_width]`.
571
+ f: Float32 FIR filter of the shape
572
+ `[filter_height, filter_width]` (non-separable),
573
+ `[filter_taps]` (separable), or
574
+ `None` (identity).
575
+ up: Integer upsampling factor. Can be a single int or a list/tuple
576
+ `[x, y]` (default: 1).
577
+ padding: Padding with respect to the output. Can be a single number or a
578
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
579
+ (default: 0).
580
+ flip_filter: False = convolution, True = correlation (default: False).
581
+ gain: Overall scaling factor for signal magnitude (default: 1).
582
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
583
+
584
+ Returns:
585
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
586
+ """
587
+ upx, upy = _parse_scaling(up)
588
+ # upx, upy = up, up
589
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
590
+ # padx0, padx1, pady0, pady1 = padding, padding, padding, padding
591
+ fw, fh = _get_filter_size(f)
592
+ p = [
593
+ padx0 + (fw + upx - 1) // 2,
594
+ padx1 + (fw - upx) // 2,
595
+ pady0 + (fh + upy - 1) // 2,
596
+ pady1 + (fh - upy) // 2,
597
+ ]
598
+ return upfirdn2d(
599
+ x,
600
+ f,
601
+ up=up,
602
+ padding=p,
603
+ flip_filter=flip_filter,
604
+ gain=gain * upx * upy,
605
+ impl=impl,
606
+ )
607
+
608
+
609
+ class MinibatchStdLayer(torch.nn.Module):
610
+ def __init__(self, group_size, num_channels=1):
611
+ super().__init__()
612
+ self.group_size = group_size
613
+ self.num_channels = num_channels
614
+
615
+ def forward(self, x):
616
+ N, C, H, W = x.shape
617
+ G = (
618
+ torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
619
+ if self.group_size is not None
620
+ else N
621
+ )
622
+ F = self.num_channels
623
+ c = C // F
624
+
625
+ y = x.reshape(
626
+ G, -1, F, c, H, W
627
+ ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
628
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
629
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
630
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
631
+ y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
632
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
633
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
634
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
635
+ return x
636
+
637
+
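A small usage sketch for MinibatchStdLayer, assuming the class above is in scope: the layer appends num_channels statistics maps, so the channel count grows from C to C + num_channels while the spatial size is unchanged.

import torch

layer = MinibatchStdLayer(group_size=4, num_channels=1)
x = torch.randn(8, 16, 4, 4)        # [N, C, H, W]; N must be divisible by the effective group size
y = layer(x)
assert y.shape == (8, 17, 4, 4)     # one stddev feature map appended per sample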
638
+ class FullyConnectedLayer(torch.nn.Module):
639
+ def __init__(
640
+ self,
641
+ in_features, # Number of input features.
642
+ out_features, # Number of output features.
643
+ bias=True, # Apply additive bias before the activation function?
644
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
645
+ lr_multiplier=1, # Learning rate multiplier.
646
+ bias_init=0, # Initial value for the additive bias.
647
+ ):
648
+ super().__init__()
649
+ self.weight = torch.nn.Parameter(
650
+ torch.randn([out_features, in_features]) / lr_multiplier
651
+ )
652
+ self.bias = (
653
+ torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
654
+ if bias
655
+ else None
656
+ )
657
+ self.activation = activation
658
+
659
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
660
+ self.bias_gain = lr_multiplier
661
+
662
+ def forward(self, x):
663
+ w = self.weight * self.weight_gain
664
+ b = self.bias
665
+ if b is not None and self.bias_gain != 1:
666
+ b = b * self.bias_gain
667
+
668
+ if self.activation == "linear" and b is not None:
669
+ # out = torch.addmm(b.unsqueeze(0), x, w.t())
670
+ x = x.matmul(w.t())
671
+ out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
672
+ else:
673
+ x = x.matmul(w.t())
674
+ out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
675
+ return out
676
+
677
+
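A usage sketch for FullyConnectedLayer, assuming the class above and its torch/numpy imports are in scope. The weight is stored unscaled and multiplied by lr_multiplier / sqrt(in_features) on every forward pass (the StyleGAN-style equalized-learning-rate trick):

import torch

fc = FullyConnectedLayer(in_features=512, out_features=3, activation="linear")
x = torch.randn(4, 512)
y = fc(x)                           # linear path: x @ (w * weight_gain).t() + bias
assert y.shape == (4, 3)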
678
+ def _conv2d_wrapper(
679
+ x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
680
+ ):
681
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
682
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
683
+
684
+ # Flip weight if requested.
685
+ if (
686
+ not flip_weight
687
+ ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
688
+ w = w.flip([2, 3])
689
+
690
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
691
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
692
+ if (
693
+ kw == 1
694
+ and kh == 1
695
+ and stride == 1
696
+ and padding in [0, [0, 0], (0, 0)]
697
+ and not transpose
698
+ ):
699
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
700
+ if out_channels <= 4 and groups == 1:
701
+ in_shape = x.shape
702
+ x = w.squeeze(3).squeeze(2) @ x.reshape(
703
+ [in_shape[0], in_channels_per_group, -1]
704
+ )
705
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
706
+ else:
707
+ x = x.to(memory_format=torch.contiguous_format)
708
+ w = w.to(memory_format=torch.contiguous_format)
709
+ x = conv2d(x, w, groups=groups)
710
+ return x.to(memory_format=torch.channels_last)
711
+
712
+ # Otherwise => execute using conv2d_gradfix.
713
+ op = conv_transpose2d if transpose else conv2d
714
+ return op(x, w, stride=stride, padding=padding, groups=groups)
715
+
716
+
717
+ def conv2d_resample(
718
+ x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
719
+ ):
720
+ r"""2D convolution with optional up/downsampling.
721
+
722
+ Padding is performed only once at the beginning, not between the operations.
723
+
724
+ Args:
725
+ x: Input tensor of shape
726
+ `[batch_size, in_channels, in_height, in_width]`.
727
+ w: Weight tensor of shape
728
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
729
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
730
+ calling setup_filter(). None = identity (default).
731
+ up: Integer upsampling factor (default: 1).
732
+ down: Integer downsampling factor (default: 1).
733
+ padding: Padding with respect to the upsampled image. Can be a single number
734
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
735
+ (default: 0).
736
+ groups: Split input channels into N groups (default: 1).
737
+ flip_weight: False = convolution, True = correlation (default: True).
738
+ flip_filter: False = convolution, True = correlation (default: False).
739
+
740
+ Returns:
741
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
742
+ """
743
+ # Validate arguments.
744
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
745
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
746
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2])
747
+ assert isinstance(up, int) and (up >= 1)
748
+ assert isinstance(down, int) and (down >= 1)
749
+ # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
750
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
751
+ fw, fh = _get_filter_size(f)
752
+ # px0, px1, py0, py1 = _parse_padding(padding)
753
+ px0, px1, py0, py1 = padding, padding, padding, padding
754
+
755
+ # Adjust padding to account for up/downsampling.
756
+ if up > 1:
757
+ px0 += (fw + up - 1) // 2
758
+ px1 += (fw - up) // 2
759
+ py0 += (fh + up - 1) // 2
760
+ py1 += (fh - up) // 2
761
+ if down > 1:
762
+ px0 += (fw - down + 1) // 2
763
+ px1 += (fw - down) // 2
764
+ py0 += (fh - down + 1) // 2
765
+ py1 += (fh - down) // 2
766
+
767
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
768
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
769
+ x = upfirdn2d(
770
+ x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
771
+ )
772
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
773
+ return x
774
+
775
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
776
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
777
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
778
+ x = upfirdn2d(
779
+ x=x,
780
+ f=f,
781
+ up=up,
782
+ padding=[px0, px1, py0, py1],
783
+ gain=up**2,
784
+ flip_filter=flip_filter,
785
+ )
786
+ return x
787
+
788
+ # Fast path: downsampling only => use strided convolution.
789
+ if down > 1 and up == 1:
790
+ x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
791
+ x = _conv2d_wrapper(
792
+ x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
793
+ )
794
+ return x
795
+
796
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
797
+ if up > 1:
798
+ if groups == 1:
799
+ w = w.transpose(0, 1)
800
+ else:
801
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
802
+ w = w.transpose(1, 2)
803
+ w = w.reshape(
804
+ groups * in_channels_per_group, out_channels // groups, kh, kw
805
+ )
806
+ px0 -= kw - 1
807
+ px1 -= kw - up
808
+ py0 -= kh - 1
809
+ py1 -= kh - up
810
+ pxt = max(min(-px0, -px1), 0)
811
+ pyt = max(min(-py0, -py1), 0)
812
+ x = _conv2d_wrapper(
813
+ x=x,
814
+ w=w,
815
+ stride=up,
816
+ padding=[pyt, pxt],
817
+ groups=groups,
818
+ transpose=True,
819
+ flip_weight=(not flip_weight),
820
+ )
821
+ x = upfirdn2d(
822
+ x=x,
823
+ f=f,
824
+ padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
825
+ gain=up**2,
826
+ flip_filter=flip_filter,
827
+ )
828
+ if down > 1:
829
+ x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
830
+ return x
831
+
832
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
833
+ if up == 1 and down == 1:
834
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
835
+ return _conv2d_wrapper(
836
+ x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
837
+ )
838
+
839
+ # Fallback: Generic reference implementation.
840
+ x = upfirdn2d(
841
+ x=x,
842
+ f=(f if up > 1 else None),
843
+ up=up,
844
+ padding=[px0, px1, py0, py1],
845
+ gain=up**2,
846
+ flip_filter=flip_filter,
847
+ )
848
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
849
+ if down > 1:
850
+ x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
851
+ return x
852
+
853
+
854
+ class Conv2dLayer(torch.nn.Module):
855
+ def __init__(
856
+ self,
857
+ in_channels, # Number of input channels.
858
+ out_channels, # Number of output channels.
859
+ kernel_size, # Width and height of the convolution kernel.
860
+ bias=True, # Apply additive bias before the activation function?
861
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
862
+ up=1, # Integer upsampling factor.
863
+ down=1, # Integer downsampling factor.
864
+ resample_filter=[
865
+ 1,
866
+ 3,
867
+ 3,
868
+ 1,
869
+ ], # Low-pass filter to apply when resampling activations.
870
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
871
+ channels_last=False, # Expect the input to have memory_format=channels_last?
872
+ trainable=True, # Update the weights of this layer during training?
873
+ ):
874
+ super().__init__()
875
+ self.activation = activation
876
+ self.up = up
877
+ self.down = down
878
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
879
+ self.conv_clamp = conv_clamp
880
+ self.padding = kernel_size // 2
881
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
882
+ self.act_gain = activation_funcs[activation].def_gain
883
+
884
+ memory_format = (
885
+ torch.channels_last if channels_last else torch.contiguous_format
886
+ )
887
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
888
+ memory_format=memory_format
889
+ )
890
+ bias = torch.zeros([out_channels]) if bias else None
891
+ if trainable:
892
+ self.weight = torch.nn.Parameter(weight)
893
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
894
+ else:
895
+ self.register_buffer("weight", weight)
896
+ if bias is not None:
897
+ self.register_buffer("bias", bias)
898
+ else:
899
+ self.bias = None
900
+
901
+ def forward(self, x, gain=1):
902
+ w = self.weight * self.weight_gain
903
+ x = conv2d_resample(
904
+ x=x,
905
+ w=w,
906
+ f=self.resample_filter,
907
+ up=self.up,
908
+ down=self.down,
909
+ padding=self.padding,
910
+ )
911
+
912
+ act_gain = self.act_gain * gain
913
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
914
+ out = bias_act(
915
+ x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
916
+ )
917
+ return out
918
+
919
+
920
+ def torch_gc():
921
+ if torch.cuda.is_available():
922
+ torch.cuda.empty_cache()
923
+ torch.cuda.ipc_collect()
924
+ gc.collect()
925
+
926
+
927
+ def set_seed(seed: int):
928
+ random.seed(seed)
929
+ np.random.seed(seed)
930
+ torch.manual_seed(seed)
931
+ torch.cuda.manual_seed_all(seed)
932
+
933
+
934
+ def get_scheduler(sd_sampler, scheduler_config):
935
+ # https://github.com/huggingface/diffusers/issues/4167
936
+ keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
937
+ scheduler_config = dict(scheduler_config)
938
+ for it in keys_to_pop:
939
+ scheduler_config.pop(it, None)
940
+
941
+ # fmt: off
942
+ samplers = {
943
+ SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
944
+ SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
945
+ SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
946
+ SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
947
+ SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
948
+ SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
949
+ SDSampler.dpm2: [KDPM2DiscreteScheduler],
950
+ SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
951
+ SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
952
+ SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
953
+ SDSampler.euler: [EulerDiscreteScheduler],
954
+ SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
955
+ SDSampler.heun: [HeunDiscreteScheduler],
956
+ SDSampler.lms: [LMSDiscreteScheduler],
957
+ SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
958
+ SDSampler.ddim: [DDIMScheduler],
959
+ SDSampler.pndm: [PNDMScheduler],
960
+ SDSampler.uni_pc: [UniPCMultistepScheduler],
961
+ SDSampler.lcm: [LCMScheduler],
962
+ }
963
+ # fmt: on
964
+ if sd_sampler in samplers:
965
+ if len(samplers[sd_sampler]) == 2:
966
+ scheduler_cls, kwargs = samplers[sd_sampler]
967
+ else:
968
+ scheduler_cls, kwargs = samplers[sd_sampler][0], {}
969
+ return scheduler_cls.from_config(scheduler_config, **kwargs)
970
+ else:
971
+ raise ValueError(sd_sampler)
972
+
973
+
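A hedged usage sketch for get_scheduler: it can rebuild a scheduler of a different class from an existing scheduler's config. SDSampler is assumed to be the enum imported at the top of this module.

from diffusers import DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler

base = EulerAncestralDiscreteScheduler()      # any scheduler provides a config to convert from
sched = get_scheduler(SDSampler.dpm_plus_plus_2m_karras, base.config)
assert isinstance(sched, DPMSolverMultistepScheduler)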
974
+ def is_local_files_only(**kwargs) -> bool:
975
+ from huggingface_hub.constants import HF_HUB_OFFLINE
976
+
977
+ return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
978
+
979
+
980
+ def handle_from_pretrained_exceptions(func, **kwargs):
981
+ try:
982
+ return func(**kwargs)
983
+ except ValueError as e:
984
+ if "You are trying to load the model files of the `variant=fp16`" in str(e):
985
+ logger.info("variant=fp16 not found, try revision=fp16")
986
+ try:
987
+ return func(**{**kwargs, "variant": None, "revision": "fp16"})
988
+ except Exception as e:
989
+ logger.info("revision=fp16 not found, try revision=main")
990
+ return func(**{**kwargs, "variant": None, "revision": "main"})
991
+ raise e
992
+ except OSError as e:
993
+ previous_traceback = traceback.format_exc()
994
+ if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
995
+ logger.info("revision=fp16 not found, try revision=main")
996
+ return func(**{**kwargs, "variant": None, "revision": "main"})
997
+ elif "Max retries exceeded" in previous_traceback:
998
+ logger.exception(
999
+ "Fetching model from HuggingFace failed. "
1000
+ "If this is your first time downloading the model, you may need to set up proxy in terminal."
1001
+ "If the model has already been downloaded, you can add --local-files-only when starting."
1002
+ )
1003
+ exit(-1)
1004
+ raise e
1005
+ except Exception as e:
1006
+ raise e
1007
+
1008
+
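A hedged sketch of how handle_from_pretrained_exceptions might wrap a diffusers loader so the fp16 variant/revision fallbacks kick in automatically; the model id and keyword arguments are illustrative only.

import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = handle_from_pretrained_exceptions(
    StableDiffusionInpaintPipeline.from_pretrained,
    pretrained_model_name_or_path="some-org/some-inpaint-model",  # placeholder id
    variant="fp16",
    torch_dtype=torch.float16,
)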
1009
+ def get_torch_dtype(device, no_half: bool):
1010
+ device = str(device)
1011
+ use_fp16 = not no_half
1012
+ use_gpu = device == "cuda"
1013
+ # https://github.com/huggingface/diffusers/issues/4480
1014
+ # pipe.enable_attention_slicing and float16 will cause black output on mps
1015
+ # if device in ["cuda", "mps"] and use_fp16:
1016
+ if device in ["cuda"] and use_fp16:
1017
+ return use_gpu, torch.float16
1018
+ return use_gpu, torch.float32
1019
+
1020
+
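For reference, what get_torch_dtype returns for the common devices, assuming the helper above is in scope (fp16 is only enabled on CUDA here, per the linked issue about black outputs on MPS):

import torch

assert get_torch_dtype("cuda", no_half=False) == (True, torch.float16)
assert get_torch_dtype("cuda", no_half=True) == (True, torch.float32)
assert get_torch_dtype("mps", no_half=False) == (False, torch.float32)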
1021
+ def enable_low_mem(pipe, enable: bool):
1022
+ if torch.backends.mps.is_available():
1023
+ # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
1024
+ # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
1025
+ if enable:
1026
+ pipe.enable_attention_slicing("max")
1027
+ else:
1028
+ # https://huggingface.co/docs/diffusers/optimization/mps
1029
+ # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
1030
+ pipe.enable_attention_slicing()
1031
+
1032
+ if enable:
1033
+ pipe.vae.enable_tiling()
iopaint/model/zits.py ADDED
@@ -0,0 +1,476 @@
1
+ import os
2
+ import time
3
+
4
+ import cv2
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model
9
+ from iopaint.schema import InpaintRequest
10
+ import numpy as np
11
+
12
+ from .base import InpaintModel
13
+
14
+ ZITS_INPAINT_MODEL_URL = os.environ.get(
15
+ "ZITS_INPAINT_MODEL_URL",
16
+ "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
17
+ )
18
+ ZITS_INPAINT_MODEL_MD5 = os.environ.get(
19
+ "ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154"
20
+ )
21
+
22
+ ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
23
+ "ZITS_EDGE_LINE_MODEL_URL",
24
+ "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
25
+ )
26
+ ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get(
27
+ "ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f"
28
+ )
29
+
30
+ ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
31
+ "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
32
+ "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
33
+ )
34
+ ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get(
35
+ "ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927"
36
+ )
37
+
38
+ ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
39
+ "ZITS_WIRE_FRAME_MODEL_URL",
40
+ "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
41
+ )
42
+ ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get(
43
+ "ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b"
44
+ )
45
+
46
+
47
+ def resize(img, height, width, center_crop=False):
48
+ imgh, imgw = img.shape[0:2]
49
+
50
+ if center_crop and imgh != imgw:
51
+ # center crop
52
+ side = np.minimum(imgh, imgw)
53
+ j = (imgh - side) // 2
54
+ i = (imgw - side) // 2
55
+ img = img[j : j + side, i : i + side, ...]
56
+
57
+ if imgh > height and imgw > width:
58
+ inter = cv2.INTER_AREA
59
+ else:
60
+ inter = cv2.INTER_LINEAR
61
+ img = cv2.resize(img, (height, width), interpolation=inter)  # cv2.resize expects (width, height); all callers pass square sizes
62
+
63
+ return img
64
+
65
+
66
+ def to_tensor(img, scale=True, norm=False):
67
+ if img.ndim == 2:
68
+ img = img[:, :, np.newaxis]
69
+ c = img.shape[-1]
70
+
71
+ if scale:
72
+ img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255)
73
+ else:
74
+ img_t = torch.from_numpy(img).permute(2, 0, 1).float()
75
+
76
+ if norm:
77
+ mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
78
+ std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
79
+ img_t = (img_t - mean) / std
80
+ return img_t
81
+
82
+
83
+ def load_masked_position_encoding(mask):
84
+ ones_filter = np.ones((3, 3), dtype=np.float32)
85
+ d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32)
86
+ d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32)
87
+ d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32)
88
+ d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32)
89
+ str_size = 256
90
+ pos_num = 128
91
+
92
+ ori_mask = mask.copy()
93
+ ori_h, ori_w = ori_mask.shape[0:2]
94
+ ori_mask = ori_mask / 255
95
+ mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
96
+ mask[mask > 0] = 255
97
+ h, w = mask.shape[0:2]
98
+ mask3 = mask.copy()
99
+ mask3 = 1.0 - (mask3 / 255.0)
100
+ pos = np.zeros((h, w), dtype=np.int32)
101
+ direct = np.zeros((h, w, 4), dtype=np.int32)
102
+ i = 0
103
+ while np.sum(1 - mask3) > 0:
104
+ i += 1
105
+ mask3_ = cv2.filter2D(mask3, -1, ones_filter)
106
+ mask3_[mask3_ > 0] = 1
107
+ sub_mask = mask3_ - mask3
108
+ pos[sub_mask == 1] = i
109
+
110
+ m = cv2.filter2D(mask3, -1, d_filter1)
111
+ m[m > 0] = 1
112
+ m = m - mask3
113
+ direct[m == 1, 0] = 1
114
+
115
+ m = cv2.filter2D(mask3, -1, d_filter2)
116
+ m[m > 0] = 1
117
+ m = m - mask3
118
+ direct[m == 1, 1] = 1
119
+
120
+ m = cv2.filter2D(mask3, -1, d_filter3)
121
+ m[m > 0] = 1
122
+ m = m - mask3
123
+ direct[m == 1, 2] = 1
124
+
125
+ m = cv2.filter2D(mask3, -1, d_filter4)
126
+ m[m > 0] = 1
127
+ m = m - mask3
128
+ direct[m == 1, 3] = 1
129
+
130
+ mask3 = mask3_
131
+
132
+ abs_pos = pos.copy()
133
+ rel_pos = pos / (str_size / 2)  # normalize to roughly 0~1 (may exceed 1)
134
+ rel_pos = (rel_pos * pos_num).astype(np.int32)
135
+ rel_pos = np.clip(rel_pos, 0, pos_num - 1)
136
+
137
+ if ori_w != w or ori_h != h:
138
+ rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
139
+ rel_pos[ori_mask == 0] = 0
140
+ direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
141
+ direct[ori_mask == 0, :] = 0
142
+
143
+ return rel_pos, abs_pos, direct
144
+
145
+
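A small shape sketch for load_masked_position_encoding, assuming the function above is in scope: the distances are computed at the 256x256 working resolution, and only rel_pos and direct are resized back to the original mask size.

import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 150:300] = 255                  # region to inpaint

rel_pos, abs_pos, direct = load_masked_position_encoding(mask)
assert rel_pos.shape == (512, 512)            # clipped distance indices in [0, 127]
assert abs_pos.shape == (256, 256)            # raw dilation step counts at the working resolution
assert direct.shape == (512, 512, 4)          # per-pixel directional flags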
146
+ def load_image(img, mask, device, sigma256=3.0):
147
+ """
148
+ Args:
149
+ img: [H, W, C] RGB
150
+ mask: [H, W], 255 marks the masked (to-be-inpainted) region
151
+ sigma256: Gaussian blur sigma used by the OpenCV Canny fallback on the 256x256 image
152
+
153
+ Returns:
154
+ dict of tensors and metadata consumed by the ZITS models
155
+ """
156
+ h, w, _ = img.shape
157
+ imgh, imgw = img.shape[0:2]
158
+ img_256 = resize(img, 256, 256)
159
+
160
+ mask = (mask > 127).astype(np.uint8) * 255
161
+ mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
162
+ mask_256[mask_256 > 0] = 255
163
+
164
+ mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
165
+ mask_512[mask_512 > 0] = 255
166
+
167
+ # original skimage implementation
168
+ # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
169
+ # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max.
170
+ # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max.
171
+
172
+ try:
173
+ import skimage
174
+
175
+ gray_256 = skimage.color.rgb2gray(img_256)
176
+ edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float)
177
+ # cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8))
178
+ # cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8))
179
+ except Exception:  # fall back to OpenCV if scikit-image is unavailable
180
+ gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
181
+ gray_256_blured = cv2.GaussianBlur(
182
+ gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256
183
+ )
184
+ edge_256 = cv2.Canny(
185
+ gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2)
186
+ )
187
+
188
+ # cv2.imwrite("opencv_edge.jpg", edge_256)
189
+
190
+ # line
191
+ img_512 = resize(img, 512, 512)
192
+
193
+ rel_pos, abs_pos, direct = load_masked_position_encoding(mask)
194
+
195
+ batch = dict()
196
+ batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device)
197
+ batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device)
198
+ batch["masks"] = to_tensor(mask).unsqueeze(0).to(device)
199
+ batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device)
200
+ batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device)
201
+ batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device)
202
+ batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device)
203
+ batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device)
204
+ batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device)
205
+ batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device)
206
+ batch["h"] = imgh
207
+ batch["w"] = imgw
208
+
209
+ return batch
210
+
211
+
212
+ def to_device(data, device):
213
+ if isinstance(data, torch.Tensor):
214
+ return data.to(device)
215
+ if isinstance(data, dict):
216
+ for key in data:
217
+ if isinstance(data[key], torch.Tensor):
218
+ data[key] = data[key].to(device)
219
+ return data
220
+ if isinstance(data, list):
221
+ return [to_device(d, device) for d in data]
+ return data  # any other type is returned unchanged
222
+
223
+
224
+ class ZITS(InpaintModel):
225
+ name = "zits"
226
+ min_size = 256
227
+ pad_mod = 32
228
+ pad_to_square = True
229
+ is_erase_model = True
230
+
231
+ def __init__(self, device, **kwargs):
232
+ """
233
+
234
+ Args:
235
+ device: torch device used to load the TorchScript models
236
+ """
237
+ super().__init__(device)
238
+ self.device = device
239
+ self.sample_edge_line_iterations = 1
240
+
241
+ def init_model(self, device, **kwargs):
242
+ self.wireframe = load_jit_model(
243
+ ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5
244
+ )
245
+ self.edge_line = load_jit_model(
246
+ ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5
247
+ )
248
+ self.structure_upsample = load_jit_model(
249
+ ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
250
+ )
251
+ self.inpaint = load_jit_model(
252
+ ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5
253
+ )
254
+
255
+ @staticmethod
256
+ def download():
257
+ download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5)
258
+ download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5)
259
+ download_model(
260
+ ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
261
+ )
262
+ download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5)
263
+
264
+ @staticmethod
265
+ def is_downloaded() -> bool:
266
+ model_paths = [
267
+ get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
268
+ get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
269
+ get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
270
+ get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
271
+ ]
272
+ return all([os.path.exists(it) for it in model_paths])
273
+
274
+ def wireframe_edge_and_line(self, items, enable: bool):
275
+ # Finally, add the "edge" and "line" keys to items
276
+ if not enable:
277
+ items["edge"] = torch.zeros_like(items["masks"])
278
+ items["line"] = torch.zeros_like(items["masks"])
279
+ return
280
+
281
+ start = time.time()
282
+ try:
283
+ line_256 = self.wireframe_forward(
284
+ items["img_512"],
285
+ h=256,
286
+ w=256,
287
+ masks=items["mask_512"],
288
+ mask_th=0.85,
289
+ )
290
+ except Exception:  # wireframe inference failed; fall back to an empty line map
291
+ line_256 = torch.zeros_like(items["mask_256"])
292
+
293
+ print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms")
294
+
295
+ # np_line = (line[0][0].numpy() * 255).astype(np.uint8)
296
+ # cv2.imwrite("line.jpg", np_line)
297
+
298
+ start = time.time()
299
+ edge_pred, line_pred = self.sample_edge_line_logits(
300
+ context=[items["img_256"], items["edge_256"], line_256],
301
+ mask=items["mask_256"].clone(),
302
+ iterations=self.sample_edge_line_iterations,
303
+ add_v=0.05,
304
+ mul_v=4,
305
+ )
306
+ print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms")
307
+
308
+ # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
309
+ # cv2.imwrite("edge_pred.jpg", np_edge_pred)
310
+ # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
311
+ # cv2.imwrite("line_pred.jpg", np_line_pred)
312
+ # exit()
313
+
314
+ input_size = min(items["h"], items["w"])
315
+ if input_size > 256:
316
+ while edge_pred.shape[2] < input_size:
317
+ edge_pred = self.structure_upsample(edge_pred)
318
+ edge_pred = torch.sigmoid((edge_pred + 2) * 2)
319
+
320
+ line_pred = self.structure_upsample(line_pred)
321
+ line_pred = torch.sigmoid((line_pred + 2) * 2)
322
+
323
+ edge_pred = F.interpolate(
324
+ edge_pred,
325
+ size=(input_size, input_size),
326
+ mode="bilinear",
327
+ align_corners=False,
328
+ )
329
+ line_pred = F.interpolate(
330
+ line_pred,
331
+ size=(input_size, input_size),
332
+ mode="bilinear",
333
+ align_corners=False,
334
+ )
335
+
336
+ # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
337
+ # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
338
+ # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
339
+ # cv2.imwrite("line_pred_upsample.jpg", np_line_pred)
340
+ # exit()
341
+
342
+ items["edge"] = edge_pred.detach()
343
+ items["line"] = line_pred.detach()
344
+
345
+ @torch.no_grad()
346
+ def forward(self, image, mask, config: InpaintRequest):
347
+ """Input images and output images have same size
348
+ images: [H, W, C] RGB
349
+ masks: [H, W]
350
+ return: BGR IMAGE
351
+ """
352
+ mask = mask[:, :, 0]
353
+ items = load_image(image, mask, device=self.device)
354
+
355
+ self.wireframe_edge_and_line(items, config.zits_wireframe)
356
+
357
+ inpainted_image = self.inpaint(
358
+ items["images"],
359
+ items["masks"],
360
+ items["edge"],
361
+ items["line"],
362
+ items["rel_pos"],
363
+ items["direct"],
364
+ )
365
+
366
+ inpainted_image = inpainted_image * 255.0
367
+ inpainted_image = (
368
+ inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
369
+ )
370
+ inpainted_image = inpainted_image[:, :, ::-1]
371
+
372
+ # cv2.imwrite("inpainted.jpg", inpainted_image)
373
+ # exit()
374
+
375
+ return inpainted_image
376
+
377
+ def wireframe_forward(self, images, h, w, masks, mask_th=0.925):
378
+ lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1)
379
+ lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1)
380
+ images = images * 255.0
381
+ # LCNN fills masked pixels with 127.5
382
+ masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5
383
+ masked_images = (masked_images - lcnn_mean) / lcnn_std
384
+
385
+ def to_int(x):
386
+ return tuple(map(int, x))
387
+
388
+ lines_tensor = []
389
+ lmap = np.zeros((h, w))
390
+
391
+ output_masked = self.wireframe(masked_images)
392
+
393
+ output_masked = to_device(output_masked, "cpu")
394
+ if output_masked["num_proposals"] == 0:
395
+ lines_masked = []
396
+ scores_masked = []
397
+ else:
398
+ lines_masked = output_masked["lines_pred"].numpy()
399
+ lines_masked = [
400
+ [line[1] * h, line[0] * w, line[3] * h, line[2] * w]
401
+ for line in lines_masked
402
+ ]
403
+ scores_masked = output_masked["lines_score"].numpy()
404
+
405
+ for line, score in zip(lines_masked, scores_masked):
406
+ if score > mask_th:
407
+ try:
408
+ import skimage
409
+
410
+ rr, cc, value = skimage.draw.line_aa(
411
+ *to_int(line[0:2]), *to_int(line[2:4])
412
+ )
413
+ lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
414
+ except Exception:  # fall back to cv2.line if scikit-image is unavailable
415
+ cv2.line(
416
+ lmap,
417
+ to_int(line[0:2][::-1]),
418
+ to_int(line[2:4][::-1]),
419
+ (1, 1, 1),
420
+ 1,
421
+ cv2.LINE_AA,
422
+ )
423
+
424
+ lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
425
+ lines_tensor.append(to_tensor(lmap).unsqueeze(0))
426
+
427
+ lines_tensor = torch.cat(lines_tensor, dim=0)
428
+ return lines_tensor.detach().to(self.device)
429
+
430
+ def sample_edge_line_logits(
431
+ self, context, mask=None, iterations=1, add_v=0, mul_v=4
432
+ ):
433
+ [img, edge, line] = context
434
+
435
+ img = img * (1 - mask)
436
+ edge = edge * (1 - mask)
437
+ line = line * (1 - mask)
438
+
439
+ for i in range(iterations):
440
+ edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask)
441
+
442
+ edge_pred = torch.sigmoid(edge_logits)
443
+ line_pred = torch.sigmoid((line_logits + add_v) * mul_v)
444
+ edge = edge + edge_pred * mask
445
+ edge[edge >= 0.25] = 1
446
+ edge[edge < 0.25] = 0
447
+ line = line + line_pred * mask
448
+
449
+ b, _, h, w = edge_pred.shape
450
+ edge_pred = edge_pred.reshape(b, -1, 1)
451
+ line_pred = line_pred.reshape(b, -1, 1)
452
+ mask = mask.reshape(b, -1)
453
+
454
+ edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1)
455
+ line_probs = torch.cat([1 - line_pred, line_pred], dim=-1)
456
+ edge_probs[:, :, 1] += 0.5
457
+ line_probs[:, :, 1] += 0.5
458
+ edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
459
+ line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)
460
+
461
+ indices = torch.sort(
462
+ edge_max_probs + line_max_probs, dim=-1, descending=True
463
+ )[1]
464
+
465
+ for ii in range(b):
466
+ keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))
467
+
468
+ assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!"
469
+ mask[ii][indices[ii, :keep]] = 0
470
+
471
+ mask = mask.reshape(b, 1, h, w)
472
+ edge = edge * (1 - mask)
473
+ line = line * (1 - mask)
474
+
475
+ edge, line = edge.to(torch.float32), line.to(torch.float32)
476
+ return edge, line
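A hedged usage sketch for the ZITS model above. Constructing the model assumes the InpaintModel base class (imported from .base, not shown here) invokes init_model, and that the four TorchScript checkpoints can be fetched or are already cached:

ZITS.download()                     # fetches the four TorchScript models if missing
assert ZITS.is_downloaded()

model = ZITS(device="cpu")          # wireframe, edge-line, structure-upsample and inpaint nets
# `forward` expects an RGB image [H, W, C], a mask [H, W, 1] with 255 marking holes,
# and an InpaintRequest whose `zits_wireframe` flag toggles the wireframe branch;
# the returned array is BGR.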
iopaint/plugins/segment_anything/modeling/tiny_vit_sam.py ADDED
@@ -0,0 +1,822 @@
1
+ # --------------------------------------------------------
2
+ # TinyViT Model Architecture
3
+ # Copyright (c) 2022 Microsoft
4
+ # Adapted from LeViT and Swin Transformer
5
+ # LeViT: (https://github.com/facebookresearch/levit)
6
+ # Swin: (https://github.com/microsoft/swin-transformer)
7
+ # Build the TinyViT Model
8
+ # --------------------------------------------------------
9
+
10
+ import collections
11
+ import itertools
12
+ import math
13
+ import warnings
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint as checkpoint
18
+ from typing import Tuple
19
+
20
+
21
+ def _ntuple(n):
22
+ def parse(x):
23
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
24
+ return x
25
+ return tuple(itertools.repeat(x, n))
26
+
27
+ return parse
28
+
29
+
30
+ to_2tuple = _ntuple(2)
31
+
32
+
33
+ def _trunc_normal_(tensor, mean, std, a, b):
34
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
35
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
36
+ def norm_cdf(x):
37
+ # Computes standard normal cumulative distribution function
38
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
39
+
40
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
41
+ warnings.warn(
42
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
43
+ "The distribution of values may be incorrect.",
44
+ stacklevel=2,
45
+ )
46
+
47
+ # Values are generated by using a truncated uniform distribution and
48
+ # then using the inverse CDF for the normal distribution.
49
+ # Get upper and lower cdf values
50
+ l = norm_cdf((a - mean) / std)
51
+ u = norm_cdf((b - mean) / std)
52
+
53
+ # Uniformly fill tensor with values from [l, u], then translate to
54
+ # [2l-1, 2u-1].
55
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
56
+
57
+ # Use inverse cdf transform for normal distribution to get truncated
58
+ # standard normal
59
+ tensor.erfinv_()
60
+
61
+ # Transform to proper mean, std
62
+ tensor.mul_(std * math.sqrt(2.0))
63
+ tensor.add_(mean)
64
+
65
+ # Clamp to ensure it's in the proper range
66
+ tensor.clamp_(min=a, max=b)
67
+ return tensor
68
+
69
+
70
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
71
+ # type: (Tensor, float, float, float, float) -> Tensor
72
+ r"""Fills the input Tensor with values drawn from a truncated
73
+ normal distribution. The values are effectively drawn from the
74
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
75
+ with values outside :math:`[a, b]` redrawn until they are within
76
+ the bounds. The method used for generating the random values works
77
+ best when :math:`a \leq \text{mean} \leq b`.
78
+
79
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
80
+ applied while sampling the normal with mean/std applied, therefore a, b args
81
+ should be adjusted to match the range of mean, std args.
82
+
83
+ Args:
84
+ tensor: an n-dimensional `torch.Tensor`
85
+ mean: the mean of the normal distribution
86
+ std: the standard deviation of the normal distribution
87
+ a: the minimum cutoff value
88
+ b: the maximum cutoff value
89
+ Examples:
90
+ >>> w = torch.empty(3, 5)
91
+ >>> nn.init.trunc_normal_(w)
92
+ """
93
+ with torch.no_grad():
94
+ return _trunc_normal_(tensor, mean, std, a, b)
95
+
96
+
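A quick check of the in-place initializer above: values are drawn with the requested mean and std via the inverse-CDF trick and then clamped to the absolute cutoffs [a, b].

import torch

w = torch.empty(256, 256)
trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)
assert w.abs().max() <= 2.0           # all values lie inside the [a, b] cutoffs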
97
+ def drop_path(
98
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
99
+ ):
100
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
101
+
102
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
103
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
104
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
105
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
106
+ 'survival rate' as the argument.
107
+
108
+ """
109
+ if drop_prob == 0.0 or not training:
110
+ return x
111
+ keep_prob = 1 - drop_prob
112
+ shape = (x.shape[0],) + (1,) * (
113
+ x.ndim - 1
114
+ ) # work with diff dim tensors, not just 2D ConvNets
115
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
116
+ if keep_prob > 0.0 and scale_by_keep:
117
+ random_tensor.div_(keep_prob)
118
+ return x * random_tensor
119
+
120
+
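A small sketch of the stochastic-depth behaviour of drop_path above: in training mode whole samples are either zeroed or rescaled by 1 / keep_prob, while in eval mode the input passes through untouched.

import torch

x = torch.ones(8, 3, 2, 2)
y = drop_path(x, drop_prob=0.5, training=True)
assert set(y.unique().tolist()) <= {0.0, 2.0}     # dropped samples are 0, kept ones scaled by 2

assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)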
121
+ class TimmDropPath(nn.Module):
122
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
123
+
124
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
125
+ super(TimmDropPath, self).__init__()
126
+ self.drop_prob = drop_prob
127
+ self.scale_by_keep = scale_by_keep
128
+
129
+ def forward(self, x):
130
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
131
+
132
+ def extra_repr(self):
133
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
134
+
135
+
136
+ class Conv2d_BN(torch.nn.Sequential):
137
+ def __init__(
138
+ self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1
139
+ ):
140
+ super().__init__()
141
+ self.add_module(
142
+ "c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
143
+ )
144
+ bn = torch.nn.BatchNorm2d(b)
145
+ torch.nn.init.constant_(bn.weight, bn_weight_init)
146
+ torch.nn.init.constant_(bn.bias, 0)
147
+ self.add_module("bn", bn)
148
+
149
+ @torch.no_grad()
150
+ def fuse(self):
151
+ c, bn = self._modules.values()
152
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
153
+ w = c.weight * w[:, None, None, None]
154
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
155
+ m = torch.nn.Conv2d(
156
+ w.size(1) * self.c.groups,
157
+ w.size(0),
158
+ w.shape[2:],
159
+ stride=self.c.stride,
160
+ padding=self.c.padding,
161
+ dilation=self.c.dilation,
162
+ groups=self.c.groups,
163
+ )
164
+ m.weight.data.copy_(w)
165
+ m.bias.data.copy_(b)
166
+ return m
167
+
168
+
169
+ class DropPath(TimmDropPath):
170
+ def __init__(self, drop_prob=None):
171
+ super().__init__(drop_prob=drop_prob)
172
+ self.drop_prob = drop_prob
173
+
174
+ def __repr__(self):
175
+ msg = super().__repr__()
176
+ msg += f"(drop_prob={self.drop_prob})"
177
+ return msg
178
+
179
+
180
+ class PatchEmbed(nn.Module):
181
+ def __init__(self, in_chans, embed_dim, resolution, activation):
182
+ super().__init__()
183
+ img_size: Tuple[int, int] = to_2tuple(resolution)
184
+ self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
185
+ self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
186
+ self.in_chans = in_chans
187
+ self.embed_dim = embed_dim
188
+ n = embed_dim
189
+ self.seq = nn.Sequential(
190
+ Conv2d_BN(in_chans, n // 2, 3, 2, 1),
191
+ activation(),
192
+ Conv2d_BN(n // 2, n, 3, 2, 1),
193
+ )
194
+
195
+ def forward(self, x):
196
+ return self.seq(x)
197
+
198
+
199
+ class MBConv(nn.Module):
200
+ def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
201
+ super().__init__()
202
+ self.in_chans = in_chans
203
+ self.hidden_chans = int(in_chans * expand_ratio)
204
+ self.out_chans = out_chans
205
+
206
+ self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
207
+ self.act1 = activation()
208
+
209
+ self.conv2 = Conv2d_BN(
210
+ self.hidden_chans,
211
+ self.hidden_chans,
212
+ ks=3,
213
+ stride=1,
214
+ pad=1,
215
+ groups=self.hidden_chans,
216
+ )
217
+ self.act2 = activation()
218
+
219
+ self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
220
+ self.act3 = activation()
221
+
222
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
223
+
224
+ def forward(self, x):
225
+ shortcut = x
226
+
227
+ x = self.conv1(x)
228
+ x = self.act1(x)
229
+
230
+ x = self.conv2(x)
231
+ x = self.act2(x)
232
+
233
+ x = self.conv3(x)
234
+
235
+ x = self.drop_path(x)
236
+
237
+ x += shortcut
238
+ x = self.act3(x)
239
+
240
+ return x
241
+
242
+
243
+ class PatchMerging(nn.Module):
244
+ def __init__(self, input_resolution, dim, out_dim, activation):
245
+ super().__init__()
246
+
247
+ self.input_resolution = input_resolution
248
+ self.dim = dim
249
+ self.out_dim = out_dim
250
+ self.act = activation()
251
+ self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
252
+ stride_c = 2
253
+ if out_dim == 320 or out_dim == 448 or out_dim == 576:
254
+ stride_c = 1
255
+ self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
256
+ self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
257
+
258
+ def forward(self, x):
259
+ if x.ndim == 3:
260
+ H, W = self.input_resolution
261
+ B = len(x)
262
+ # (B, C, H, W)
263
+ x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
264
+
265
+ x = self.conv1(x)
266
+ x = self.act(x)
267
+
268
+ x = self.conv2(x)
269
+ x = self.act(x)
270
+ x = self.conv3(x)
271
+ x = x.flatten(2).transpose(1, 2)
272
+ return x
273
+
274
+
275
+ class ConvLayer(nn.Module):
276
+ def __init__(
277
+ self,
278
+ dim,
279
+ input_resolution,
280
+ depth,
281
+ activation,
282
+ drop_path=0.0,
283
+ downsample=None,
284
+ use_checkpoint=False,
285
+ out_dim=None,
286
+ conv_expand_ratio=4.0,
287
+ ):
288
+ super().__init__()
289
+ self.dim = dim
290
+ self.input_resolution = input_resolution
291
+ self.depth = depth
292
+ self.use_checkpoint = use_checkpoint
293
+
294
+ # build blocks
295
+ self.blocks = nn.ModuleList(
296
+ [
297
+ MBConv(
298
+ dim,
299
+ dim,
300
+ conv_expand_ratio,
301
+ activation,
302
+ drop_path[i] if isinstance(drop_path, list) else drop_path,
303
+ )
304
+ for i in range(depth)
305
+ ]
306
+ )
307
+
308
+ # patch merging layer
309
+ if downsample is not None:
310
+ self.downsample = downsample(
311
+ input_resolution, dim=dim, out_dim=out_dim, activation=activation
312
+ )
313
+ else:
314
+ self.downsample = None
315
+
316
+ def forward(self, x):
317
+ for blk in self.blocks:
318
+ if self.use_checkpoint:
319
+ x = checkpoint.checkpoint(blk, x)
320
+ else:
321
+ x = blk(x)
322
+ if self.downsample is not None:
323
+ x = self.downsample(x)
324
+ return x
325
+
326
+
327
+ class Mlp(nn.Module):
328
+ def __init__(
329
+ self,
330
+ in_features,
331
+ hidden_features=None,
332
+ out_features=None,
333
+ act_layer=nn.GELU,
334
+ drop=0.0,
335
+ ):
336
+ super().__init__()
337
+ out_features = out_features or in_features
338
+ hidden_features = hidden_features or in_features
339
+ self.norm = nn.LayerNorm(in_features)
340
+ self.fc1 = nn.Linear(in_features, hidden_features)
341
+ self.fc2 = nn.Linear(hidden_features, out_features)
342
+ self.act = act_layer()
343
+ self.drop = nn.Dropout(drop)
344
+
345
+ def forward(self, x):
346
+ x = self.norm(x)
347
+
348
+ x = self.fc1(x)
349
+ x = self.act(x)
350
+ x = self.drop(x)
351
+ x = self.fc2(x)
352
+ x = self.drop(x)
353
+ return x
354
+
355
+
356
+ class Attention(torch.nn.Module):
357
+ def __init__(
358
+ self,
359
+ dim,
360
+ key_dim,
361
+ num_heads=8,
362
+ attn_ratio=4,
363
+ resolution=(14, 14),
364
+ ):
365
+ super().__init__()
366
+ # (h, w)
367
+ assert isinstance(resolution, tuple) and len(resolution) == 2
368
+ self.num_heads = num_heads
369
+ self.scale = key_dim**-0.5
370
+ self.key_dim = key_dim
371
+ self.nh_kd = nh_kd = key_dim * num_heads
372
+ self.d = int(attn_ratio * key_dim)
373
+ self.dh = int(attn_ratio * key_dim) * num_heads
374
+ self.attn_ratio = attn_ratio
375
+ h = self.dh + nh_kd * 2
376
+
377
+ self.norm = nn.LayerNorm(dim)
378
+ self.qkv = nn.Linear(dim, h)
379
+ self.proj = nn.Linear(self.dh, dim)
380
+
381
+ points = list(itertools.product(range(resolution[0]), range(resolution[1])))
382
+ N = len(points)
383
+ attention_offsets = {}
384
+ idxs = []
385
+ for p1 in points:
386
+ for p2 in points:
387
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
388
+ if offset not in attention_offsets:
389
+ attention_offsets[offset] = len(attention_offsets)
390
+ idxs.append(attention_offsets[offset])
391
+ self.attention_biases = torch.nn.Parameter(
392
+ torch.zeros(num_heads, len(attention_offsets))
393
+ )
394
+ self.register_buffer(
395
+ "attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False
396
+ )
397
+
398
+ @torch.no_grad()
399
+ def train(self, mode=True):
400
+ super().train(mode)
401
+ if mode and hasattr(self, "ab"):
402
+ del self.ab
403
+ else:
404
+ self.register_buffer(
405
+ "ab",
406
+ self.attention_biases[:, self.attention_bias_idxs],
407
+ persistent=False,
408
+ )
409
+
410
+ def forward(self, x): # x (B,N,C)
411
+ B, N, _ = x.shape
412
+
413
+ # Normalization
414
+ x = self.norm(x)
415
+
416
+ qkv = self.qkv(x)
417
+ # (B, N, num_heads, d)
418
+ q, k, v = qkv.view(B, N, self.num_heads, -1).split(
419
+ [self.key_dim, self.key_dim, self.d], dim=3
420
+ )
421
+ # (B, num_heads, N, d)
422
+ q = q.permute(0, 2, 1, 3)
423
+ k = k.permute(0, 2, 1, 3)
424
+ v = v.permute(0, 2, 1, 3)
425
+
426
+ attn = (q @ k.transpose(-2, -1)) * self.scale + (
427
+ self.attention_biases[:, self.attention_bias_idxs]
428
+ if self.training
429
+ else self.ab
430
+ )
431
+ attn = attn.softmax(dim=-1)
432
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
433
+ x = self.proj(x)
434
+ return x
435
+
436
+
437
+ class TinyViTBlock(nn.Module):
438
+ r"""TinyViT Block.
439
+
440
+ Args:
441
+ dim (int): Number of input channels.
442
+ input_resolution (tuple[int, int]): Input resolution.
443
+ num_heads (int): Number of attention heads.
444
+ window_size (int): Window size.
445
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
446
+ drop (float, optional): Dropout rate. Default: 0.0
447
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
448
+ local_conv_size (int): the kernel size of the convolution between
449
+ Attention and MLP. Default: 3
450
+ activation: the activation function. Default: nn.GELU
451
+ """
452
+
453
+ def __init__(
454
+ self,
455
+ dim,
456
+ input_resolution,
457
+ num_heads,
458
+ window_size=7,
459
+ mlp_ratio=4.0,
460
+ drop=0.0,
461
+ drop_path=0.0,
462
+ local_conv_size=3,
463
+ activation=nn.GELU,
464
+ ):
465
+ super().__init__()
466
+ self.dim = dim
467
+ self.input_resolution = input_resolution
468
+ self.num_heads = num_heads
469
+ assert window_size > 0, "window_size must be greater than 0"
470
+ self.window_size = window_size
471
+ self.mlp_ratio = mlp_ratio
472
+
473
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
474
+
475
+ assert dim % num_heads == 0, "dim must be divisible by num_heads"
476
+ head_dim = dim // num_heads
477
+
478
+ window_resolution = (window_size, window_size)
479
+ self.attn = Attention(
480
+ dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution
481
+ )
482
+
483
+ mlp_hidden_dim = int(dim * mlp_ratio)
484
+ mlp_activation = activation
485
+ self.mlp = Mlp(
486
+ in_features=dim,
487
+ hidden_features=mlp_hidden_dim,
488
+ act_layer=mlp_activation,
489
+ drop=drop,
490
+ )
491
+
492
+ pad = local_conv_size // 2
493
+ self.local_conv = Conv2d_BN(
494
+ dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim
495
+ )
496
+
497
+ def forward(self, x):
498
+ H, W = self.input_resolution
499
+ B, L, C = x.shape
500
+ assert L == H * W, "input feature has wrong size"
501
+ res_x = x
502
+ if H == self.window_size and W == self.window_size:
503
+ x = self.attn(x)
504
+ else:
505
+ x = x.view(B, H, W, C)
506
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
507
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
508
+ padding = pad_b > 0 or pad_r > 0
509
+
510
+ if padding:
511
+ x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
512
+
513
+ pH, pW = H + pad_b, W + pad_r
514
+ nH = pH // self.window_size
515
+ nW = pW // self.window_size
516
+ # window partition
517
+ x = (
518
+ x.view(B, nH, self.window_size, nW, self.window_size, C)
519
+ .transpose(2, 3)
520
+ .reshape(B * nH * nW, self.window_size * self.window_size, C)
521
+ )
522
+ x = self.attn(x)
523
+ # window reverse
524
+ x = (
525
+ x.view(B, nH, nW, self.window_size, self.window_size, C)
526
+ .transpose(2, 3)
527
+ .reshape(B, pH, pW, C)
528
+ )
529
+
530
+ if padding:
531
+ x = x[:, :H, :W].contiguous()
532
+
533
+ x = x.view(B, L, C)
534
+
535
+ x = res_x + self.drop_path(x)
536
+
537
+ x = x.transpose(1, 2).reshape(B, C, H, W)
538
+ x = self.local_conv(x)
539
+ x = x.view(B, C, L).transpose(1, 2)
540
+
541
+ x = x + self.drop_path(self.mlp(x))
542
+ return x
543
+
544
+ def extra_repr(self) -> str:
545
+ return (
546
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
547
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
548
+ )
549
+
550
+
551
+ class BasicLayer(nn.Module):
552
+ """A basic TinyViT layer for one stage.
553
+
554
+ Args:
555
+ dim (int): Number of input channels.
556
+ input_resolution (tuple[int]): Input resolution.
557
+ depth (int): Number of blocks.
558
+ num_heads (int): Number of attention heads.
559
+ window_size (int): Local window size.
560
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
561
+ drop (float, optional): Dropout rate. Default: 0.0
562
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
563
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
564
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
565
+ local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
566
+ activation: the activation function. Default: nn.GELU
567
+ out_dim: the output dimension of the layer. Default: dim
568
+ """
569
+
570
+ def __init__(
571
+ self,
572
+ dim,
573
+ input_resolution,
574
+ depth,
575
+ num_heads,
576
+ window_size,
577
+ mlp_ratio=4.0,
578
+ drop=0.0,
579
+ drop_path=0.0,
580
+ downsample=None,
581
+ use_checkpoint=False,
582
+ local_conv_size=3,
583
+ activation=nn.GELU,
584
+ out_dim=None,
585
+ ):
586
+ super().__init__()
587
+ self.dim = dim
588
+ self.input_resolution = input_resolution
589
+ self.depth = depth
590
+ self.use_checkpoint = use_checkpoint
591
+
592
+ # build blocks
593
+ self.blocks = nn.ModuleList(
594
+ [
595
+ TinyViTBlock(
596
+ dim=dim,
597
+ input_resolution=input_resolution,
598
+ num_heads=num_heads,
599
+ window_size=window_size,
600
+ mlp_ratio=mlp_ratio,
601
+ drop=drop,
602
+ drop_path=drop_path[i]
603
+ if isinstance(drop_path, list)
604
+ else drop_path,
605
+ local_conv_size=local_conv_size,
606
+ activation=activation,
607
+ )
608
+ for i in range(depth)
609
+ ]
610
+ )
611
+
612
+ # patch merging layer
613
+ if downsample is not None:
614
+ self.downsample = downsample(
615
+ input_resolution, dim=dim, out_dim=out_dim, activation=activation
616
+ )
617
+ else:
618
+ self.downsample = None
619
+
620
+ def forward(self, x):
621
+ for blk in self.blocks:
622
+ if self.use_checkpoint:
623
+ x = checkpoint.checkpoint(blk, x)
624
+ else:
625
+ x = blk(x)
626
+ if self.downsample is not None:
627
+ x = self.downsample(x)
628
+ return x
629
+
630
+ def extra_repr(self) -> str:
631
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
632
+
633
+
634
+ class LayerNorm2d(nn.Module):
635
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
636
+ super().__init__()
637
+ self.weight = nn.Parameter(torch.ones(num_channels))
638
+ self.bias = nn.Parameter(torch.zeros(num_channels))
639
+ self.eps = eps
640
+
641
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
642
+ u = x.mean(1, keepdim=True)
643
+ s = (x - u).pow(2).mean(1, keepdim=True)
644
+ x = (x - u) / torch.sqrt(s + self.eps)
645
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
646
+ return x
647
+
648
+
649
+ class TinyViT(nn.Module):
650
+ def __init__(
651
+ self,
652
+ img_size=224,
653
+ in_chans=3,
654
+ num_classes=1000,
655
+ embed_dims=[96, 192, 384, 768],
656
+ depths=[2, 2, 6, 2],
657
+ num_heads=[3, 6, 12, 24],
658
+ window_sizes=[7, 7, 14, 7],
659
+ mlp_ratio=4.0,
660
+ drop_rate=0.0,
661
+ drop_path_rate=0.1,
662
+ use_checkpoint=False,
663
+ mbconv_expand_ratio=4.0,
664
+ local_conv_size=3,
665
+ layer_lr_decay=1.0,
666
+ ):
667
+ super().__init__()
668
+ self.img_size = img_size
669
+ self.num_classes = num_classes
670
+ self.depths = depths
671
+ self.num_layers = len(depths)
672
+ self.mlp_ratio = mlp_ratio
673
+
674
+ activation = nn.GELU
675
+
676
+ self.patch_embed = PatchEmbed(
677
+ in_chans=in_chans,
678
+ embed_dim=embed_dims[0],
679
+ resolution=img_size,
680
+ activation=activation,
681
+ )
682
+
683
+ patches_resolution = self.patch_embed.patches_resolution
684
+ self.patches_resolution = patches_resolution
685
+
686
+ # stochastic depth
687
+ dpr = [
688
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
689
+ ] # stochastic depth decay rule
690
+
691
+ # build layers
692
+ self.layers = nn.ModuleList()
693
+ for i_layer in range(self.num_layers):
694
+ kwargs = dict(
695
+ dim=embed_dims[i_layer],
696
+ input_resolution=(
697
+ patches_resolution[0]
698
+ // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
699
+ patches_resolution[1]
700
+ // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
701
+ ),
702
+ # input_resolution=(patches_resolution[0] // (2 ** i_layer),
703
+ # patches_resolution[1] // (2 ** i_layer)),
704
+ depth=depths[i_layer],
705
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
706
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
707
+ use_checkpoint=use_checkpoint,
708
+ out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
709
+ activation=activation,
710
+ )
711
+ if i_layer == 0:
712
+ layer = ConvLayer(
713
+ conv_expand_ratio=mbconv_expand_ratio,
714
+ **kwargs,
715
+ )
716
+ else:
717
+ layer = BasicLayer(
718
+ num_heads=num_heads[i_layer],
719
+ window_size=window_sizes[i_layer],
720
+ mlp_ratio=self.mlp_ratio,
721
+ drop=drop_rate,
722
+ local_conv_size=local_conv_size,
723
+ **kwargs,
724
+ )
725
+ self.layers.append(layer)
726
+
727
+ # Classifier head
728
+ self.norm_head = nn.LayerNorm(embed_dims[-1])
729
+ self.head = (
730
+ nn.Linear(embed_dims[-1], num_classes)
731
+ if num_classes > 0
732
+ else torch.nn.Identity()
733
+ )
734
+
735
+ # init weights
736
+ self.apply(self._init_weights)
737
+ self.set_layer_lr_decay(layer_lr_decay)
738
+ self.neck = nn.Sequential(
739
+ nn.Conv2d(
740
+ embed_dims[-1],
741
+ 256,
742
+ kernel_size=1,
743
+ bias=False,
744
+ ),
745
+ LayerNorm2d(256),
746
+ nn.Conv2d(
747
+ 256,
748
+ 256,
749
+ kernel_size=3,
750
+ padding=1,
751
+ bias=False,
752
+ ),
753
+ LayerNorm2d(256),
754
+ )
755
+
756
+ def set_layer_lr_decay(self, layer_lr_decay):
757
+ decay_rate = layer_lr_decay
758
+
759
+ # layers -> blocks (depth)
760
+ depth = sum(self.depths)
761
+ lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
762
+ # print("LR SCALES:", lr_scales)
763
+
764
+ def _set_lr_scale(m, scale):
765
+ for p in m.parameters():
766
+ p.lr_scale = scale
767
+
768
+ self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
769
+ i = 0
770
+ for layer in self.layers:
771
+ for block in layer.blocks:
772
+ block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
773
+ i += 1
774
+ if layer.downsample is not None:
775
+ layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
776
+ assert i == depth
777
+ for m in [self.norm_head, self.head]:
778
+ m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
779
+
780
+ for k, p in self.named_parameters():
781
+ p.param_name = k
782
+
783
+ def _check_lr_scale(m):
784
+ for p in m.parameters():
785
+ assert hasattr(p, "lr_scale"), p.param_name
786
+
787
+ self.apply(_check_lr_scale)
788
+
789
+ def _init_weights(self, m):
790
+ if isinstance(m, nn.Linear):
791
+ trunc_normal_(m.weight, std=0.02)
792
+ if isinstance(m, nn.Linear) and m.bias is not None:
793
+ nn.init.constant_(m.bias, 0)
794
+ elif isinstance(m, nn.LayerNorm):
795
+ nn.init.constant_(m.bias, 0)
796
+ nn.init.constant_(m.weight, 1.0)
797
+
798
+ @torch.jit.ignore
799
+ def no_weight_decay_keywords(self):
800
+ return {"attention_biases"}
801
+
802
+ def forward_features(self, x):
803
+ # x: (N, C, H, W)
804
+ x = self.patch_embed(x)
805
+
806
+ x = self.layers[0](x)
807
+ start_i = 1
808
+
809
+ for i in range(start_i, len(self.layers)):
810
+ layer = self.layers[i]
811
+ x = layer(x)
812
+ B, _, C = x.size()
813
+ x = x.view(B, 64, 64, C)
814
+ x = x.permute(0, 3, 1, 2)
815
+ x = self.neck(x)
816
+ return x
817
+
818
+ def forward(self, x):
819
+ x = self.forward_features(x)
820
+ # x = self.norm_head(x)
821
+ # x = self.head(x)
822
+ return x
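Editor's note: the class above is the TinyViT image encoder vendored for the Segment Anything plugin; forward_features patch-embeds the input, runs the convolutional and attention stages, reshapes the tokens back to a 64x64 grid, and projects them to 256 channels through the neck. Below is a minimal, hypothetical smoke test (not part of the diff): the constructor arguments mirror the MobileSAM TinyViT-5M configuration, which is an assumption, and whether the hard-coded 64x64 reshape holds depends on the PatchEmbed/PatchMerging definitions earlier in this file.

import torch

# Hypothetical smoke test: instantiate the encoder with MobileSAM-style settings
# (an assumption, not taken from this diff) and inspect the output feature map shape.
encoder = TinyViT(
    img_size=1024,
    embed_dims=[64, 128, 160, 320],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
)
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 1024, 1024))
print(feats.shape)  # a (1, 256, 64, 64) feature map is what SAM-style decoders expect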
iopaint/plugins/segment_anything/modeling/transformer.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ import math
11
+ from typing import Tuple, Type
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
+ class TwoWayTransformer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ depth: int,
20
+ embedding_dim: int,
21
+ num_heads: int,
22
+ mlp_dim: int,
23
+ activation: Type[nn.Module] = nn.ReLU,
24
+ attention_downsample_rate: int = 2,
25
+ ) -> None:
26
+ """
27
+ A transformer decoder that attends to an input image using
28
+ queries whose positional embedding is supplied.
29
+
30
+ Args:
31
+ depth (int): number of layers in the transformer
32
+ embedding_dim (int): the channel dimension for the input embeddings
33
+ num_heads (int): the number of heads for multihead attention. Must
34
+ divide embedding_dim
35
+ mlp_dim (int): the channel dimension internal to the MLP block
36
+ activation (nn.Module): the activation to use in the MLP block
37
+ """
38
+ super().__init__()
39
+ self.depth = depth
40
+ self.embedding_dim = embedding_dim
41
+ self.num_heads = num_heads
42
+ self.mlp_dim = mlp_dim
43
+ self.layers = nn.ModuleList()
44
+
45
+ for i in range(depth):
46
+ self.layers.append(
47
+ TwoWayAttentionBlock(
48
+ embedding_dim=embedding_dim,
49
+ num_heads=num_heads,
50
+ mlp_dim=mlp_dim,
51
+ activation=activation,
52
+ attention_downsample_rate=attention_downsample_rate,
53
+ skip_first_layer_pe=(i == 0),
54
+ )
55
+ )
56
+
57
+ self.final_attn_token_to_image = Attention(
58
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59
+ )
60
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
61
+
62
+ def forward(
63
+ self,
64
+ image_embedding: Tensor,
65
+ image_pe: Tensor,
66
+ point_embedding: Tensor,
67
+ ) -> Tuple[Tensor, Tensor]:
68
+ """
69
+ Args:
70
+ image_embedding (torch.Tensor): image to attend to. Should be shape
71
+ B x embedding_dim x h x w for any h and w.
72
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
73
+ have the same shape as image_embedding.
74
+ point_embedding (torch.Tensor): the embedding to add to the query points.
75
+ Must have shape B x N_points x embedding_dim for any N_points.
76
+
77
+ Returns:
78
+ torch.Tensor: the processed point_embedding
79
+ torch.Tensor: the processed image_embedding
80
+ """
81
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82
+ bs, c, h, w = image_embedding.shape
83
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
85
+
86
+ # Prepare queries
87
+ queries = point_embedding
88
+ keys = image_embedding
89
+
90
+ # Apply transformer blocks and final layernorm
91
+ for layer in self.layers:
92
+ queries, keys = layer(
93
+ queries=queries,
94
+ keys=keys,
95
+ query_pe=point_embedding,
96
+ key_pe=image_pe,
97
+ )
98
+
99
+ # Apply the final attention layer from the points to the image
100
+ q = queries + point_embedding
101
+ k = keys + image_pe
102
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103
+ queries = queries + attn_out
104
+ queries = self.norm_final_attn(queries)
105
+
106
+ return queries, keys
107
+
108
+
109
+ class TwoWayAttentionBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ embedding_dim: int,
113
+ num_heads: int,
114
+ mlp_dim: int = 2048,
115
+ activation: Type[nn.Module] = nn.ReLU,
116
+ attention_downsample_rate: int = 2,
117
+ skip_first_layer_pe: bool = False,
118
+ ) -> None:
119
+ """
120
+ A transformer block with four layers: (1) self-attention of sparse
121
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
123
+ inputs.
124
+
125
+ Arguments:
126
+ embedding_dim (int): the channel dimension of the embeddings
127
+ num_heads (int): the number of heads in the attention layers
128
+ mlp_dim (int): the hidden dimension of the mlp block
129
+ activation (nn.Module): the activation of the mlp block
130
+ skip_first_layer_pe (bool): skip the PE on the first layer
131
+ """
132
+ super().__init__()
133
+ self.self_attn = Attention(embedding_dim, num_heads)
134
+ self.norm1 = nn.LayerNorm(embedding_dim)
135
+
136
+ self.cross_attn_token_to_image = Attention(
137
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142
+ self.norm3 = nn.LayerNorm(embedding_dim)
143
+
144
+ self.norm4 = nn.LayerNorm(embedding_dim)
145
+ self.cross_attn_image_to_token = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+
149
+ self.skip_first_layer_pe = skip_first_layer_pe
150
+
151
+ def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153
+ ) -> Tuple[Tensor, Tensor]:
154
+ # Self attention block
155
+ if self.skip_first_layer_pe:
156
+ queries = self.self_attn(q=queries, k=queries, v=queries)
157
+ else:
158
+ q = queries + query_pe
159
+ attn_out = self.self_attn(q=q, k=q, v=queries)
160
+ queries = queries + attn_out
161
+ queries = self.norm1(queries)
162
+
163
+ # Cross attention block, tokens attending to image embedding
164
+ q = queries + query_pe
165
+ k = keys + key_pe
166
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167
+ queries = queries + attn_out
168
+ queries = self.norm2(queries)
169
+
170
+ # MLP block
171
+ mlp_out = self.mlp(queries)
172
+ queries = queries + mlp_out
173
+ queries = self.norm3(queries)
174
+
175
+ # Cross attention block, image embedding attending to tokens
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179
+ keys = keys + attn_out
180
+ keys = self.norm4(keys)
181
+
182
+ return queries, keys
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """
187
+ An attention layer that allows for downscaling the size of the embedding
188
+ after projection to queries, keys, and values.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ embedding_dim: int,
194
+ num_heads: int,
195
+ downsample_rate: int = 1,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.embedding_dim = embedding_dim
199
+ self.internal_dim = embedding_dim // downsample_rate
200
+ self.num_heads = num_heads
201
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202
+
203
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207
+
208
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209
+ b, n, c = x.shape
210
+ x = x.reshape(b, n, num_heads, c // num_heads)
211
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212
+
213
+ def _recombine_heads(self, x: Tensor) -> Tensor:
214
+ b, n_heads, n_tokens, c_per_head = x.shape
215
+ x = x.transpose(1, 2)
216
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217
+
218
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219
+ # Input projections
220
+ q = self.q_proj(q)
221
+ k = self.k_proj(k)
222
+ v = self.v_proj(v)
223
+
224
+ # Separate into heads
225
+ q = self._separate_heads(q, self.num_heads)
226
+ k = self._separate_heads(k, self.num_heads)
227
+ v = self._separate_heads(v, self.num_heads)
228
+
229
+ # Attention
230
+ _, _, _, c_per_head = q.shape
231
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232
+ attn = attn / math.sqrt(c_per_head)
233
+ attn = torch.softmax(attn, dim=-1)
234
+
235
+ # Get output
236
+ out = attn @ v
237
+ out = self._recombine_heads(out)
238
+ out = self.out_proj(out)
239
+
240
+ return out
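Editor's note: the transformer above alternates token self-attention, token-to-image cross-attention, an MLP, and image-to-token cross-attention, so both the prompt tokens and the flattened image embedding are updated in every layer. A minimal, hypothetical shape walkthrough follows (not part of the diff; the hyperparameter values only mirror SAM's usual defaults):

import torch

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W from the image encoder
image_pe = torch.randn(1, 256, 64, 64)          # positional encoding, same shape
point_embedding = torch.randn(1, 5, 256)        # B x N_points x C prompt tokens

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 5, 256])    refined prompt tokens
print(keys.shape)     # torch.Size([1, 4096, 256]) flattened image tokens (64 * 64)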
iopaint/plugins/segment_anything/utils/transforms.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11
+
12
+ from copy import deepcopy
13
+ from typing import Tuple
14
+
15
+
16
+ class ResizeLongestSide:
17
+ """
18
+ Resizes images to longest side 'target_length', as well as provides
19
+ methods for resizing coordinates and boxes. Provides methods for
20
+ transforming both numpy array and batched torch tensors.
21
+ """
22
+
23
+ def __init__(self, target_length: int) -> None:
24
+ self.target_length = target_length
25
+
26
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
27
+ """
28
+ Expects a numpy array with shape HxWxC in uint8 format.
29
+ """
30
+ target_size = self.get_preprocess_shape(
31
+ image.shape[0], image.shape[1], self.target_length
32
+ )
33
+ return np.array(resize(to_pil_image(image), target_size))
34
+
35
+ def apply_coords(
36
+ self, coords: np.ndarray, original_size: Tuple[int, ...]
37
+ ) -> np.ndarray:
38
+ """
39
+ Expects a numpy array of length 2 in the final dimension. Requires the
40
+ original image size in (H, W) format.
41
+ """
42
+ old_h, old_w = original_size
43
+ new_h, new_w = self.get_preprocess_shape(
44
+ original_size[0], original_size[1], self.target_length
45
+ )
46
+ coords = deepcopy(coords).astype(float)
47
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
48
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
49
+ return coords
50
+
51
+ def apply_boxes(
52
+ self, boxes: np.ndarray, original_size: Tuple[int, ...]
53
+ ) -> np.ndarray:
54
+ """
55
+ Expects a numpy array shape Bx4. Requires the original image size
56
+ in (H, W) format.
57
+ """
58
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
59
+ return boxes.reshape(-1, 4)
60
+
61
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
62
+ """
63
+ Expects batched images with shape BxCxHxW and float format. This
64
+ transformation may not exactly match apply_image. apply_image is
65
+ the transformation expected by the model.
66
+ """
67
+ # Expects an image in BCHW format. May not exactly match apply_image.
68
+ target_size = self.get_preprocess_shape(
69
+ image.shape[0], image.shape[1], self.target_length
70
+ )
71
+ return F.interpolate(
72
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
73
+ )
74
+
75
+ def apply_coords_torch(
76
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
77
+ ) -> torch.Tensor:
78
+ """
79
+ Expects a torch tensor with length 2 in the last dimension. Requires the
80
+ original image size in (H, W) format.
81
+ """
82
+ old_h, old_w = original_size
83
+ new_h, new_w = self.get_preprocess_shape(
84
+ original_size[0], original_size[1], self.target_length
85
+ )
86
+ coords = deepcopy(coords).to(torch.float)
87
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
88
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
89
+ return coords
90
+
91
+ def apply_boxes_torch(
92
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
93
+ ) -> torch.Tensor:
94
+ """
95
+ Expects a torch tensor with shape Bx4. Requires the original image
96
+ size in (H, W) format.
97
+ """
98
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
99
+ return boxes.reshape(-1, 4)
100
+
101
+ @staticmethod
102
+ def get_preprocess_shape(
103
+ oldh: int, oldw: int, long_side_length: int
104
+ ) -> Tuple[int, int]:
105
+ """
106
+ Compute the output size given input size and target long side length.
107
+ """
108
+ scale = long_side_length * 1.0 / max(oldh, oldw)
109
+ newh, neww = oldh * scale, oldw * scale
110
+ neww = int(neww + 0.5)
111
+ newh = int(newh + 0.5)
112
+ return (newh, neww)
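Editor's note: ResizeLongestSide scales the longest image side to target_length and applies the same scale factor to coordinates and boxes. A small, hypothetical usage example (not part of the diff):

import numpy as np

transform = ResizeLongestSide(target_length=1024)

image = np.zeros((480, 640, 3), dtype=np.uint8)         # H x W x C, uint8
resized = transform.apply_image(image)
print(resized.shape)                                     # (768, 1024, 3): scale = 1024 / 640

points = np.array([[100.0, 200.0]])                      # (x, y) in original pixel coordinates
print(transform.apply_coords(points, image.shape[:2]))   # [[160. 320.]] after the same 1.6x scale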
iopaint/tests/test_sdxl.py ADDED
@@ -0,0 +1,172 @@
1
+ import os
2
+
3
+ from iopaint.tests.utils import check_device, current_dir
4
+
5
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
6
+
7
+ import pytest
8
+ import torch
9
+
10
+ from iopaint.model_manager import ModelManager
11
+ from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
12
+ from iopaint.tests.test_model import get_config, assert_equal
13
+
14
+
15
+ @pytest.mark.parametrize("device", ["cuda", "mps"])
16
+ @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
17
+ @pytest.mark.parametrize("sampler", [SDSampler.ddim])
18
+ def test_sdxl(device, strategy, sampler):
19
+ sd_steps = check_device(device)
20
+
21
+ model = ModelManager(
22
+ name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
23
+ device=torch.device(device),
24
+ disable_nsfw=True,
25
+ sd_cpu_textencoder=False,
26
+ )
27
+ cfg = get_config(
28
+ strategy=strategy,
29
+ prompt="face of a fox, sitting on a bench",
30
+ sd_steps=sd_steps,
31
+ sd_strength=1.0,
32
+ sd_guidance_scale=7.0,
33
+ )
34
+ cfg.sd_sampler = sampler
35
+
36
+ assert_equal(
37
+ model,
38
+ cfg,
39
+ f"sdxl_device_{device}.png",
40
+ img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
41
+ mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
42
+ fx=2,
43
+ fy=2,
44
+ )
45
+
46
+
47
+ @pytest.mark.parametrize("device", ["cuda", "cpu"])
48
+ @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
49
+ @pytest.mark.parametrize("sampler", [SDSampler.ddim])
50
+ def test_sdxl_cpu_text_encoder(device, strategy, sampler):
51
+ sd_steps = check_device(device)
52
+
53
+ model = ModelManager(
54
+ name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
55
+ device=torch.device(device),
56
+ disable_nsfw=True,
57
+ sd_cpu_textencoder=True,
58
+ )
59
+ cfg = get_config(
60
+ strategy=strategy,
61
+ prompt="face of a fox, sitting on a bench",
62
+ sd_steps=sd_steps,
63
+ sd_strength=1.0,
64
+ sd_guidance_scale=7.0,
65
+ )
66
+ cfg.sd_sampler = sampler
67
+
68
+ assert_equal(
69
+ model,
70
+ cfg,
71
+ f"sdxl_device_{device}.png",
72
+ img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
73
+ mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
74
+ fx=2,
75
+ fy=2,
76
+ )
77
+
78
+
79
+ @pytest.mark.parametrize("device", ["cuda", "mps"])
80
+ @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
81
+ @pytest.mark.parametrize("sampler", [SDSampler.ddim])
82
+ def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
83
+ sd_steps = check_device(device)
84
+
85
+ model = ModelManager(
86
+ name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
87
+ device=torch.device(device),
88
+ disable_nsfw=True,
89
+ sd_cpu_textencoder=False,
90
+ )
91
+ cfg = get_config(
92
+ strategy=strategy,
93
+ prompt="face of a fox, sitting on a bench",
94
+ sd_steps=sd_steps,
95
+ sd_strength=1.0,
96
+ sd_guidance_scale=2.0,
97
+ sd_lcm_lora=True,
98
+ )
99
+ cfg.sd_sampler = sampler
100
+
101
+ name = f"device_{device}_{sampler}"
102
+
103
+ assert_equal(
104
+ model,
105
+ cfg,
106
+ f"sdxl_{name}_lcm_lora.png",
107
+ img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
108
+ mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
109
+ fx=2,
110
+ fy=2,
111
+ )
112
+
113
+ cfg = get_config(
114
+ strategy=strategy,
115
+ prompt="face of a fox, sitting on a bench",
116
+ sd_steps=sd_steps,
117
+ sd_guidance_scale=7.5,
118
+ sd_freeu=True,
119
+ sd_freeu_config=FREEUConfig(),
120
+ )
121
+
122
+ assert_equal(
123
+ model,
124
+ cfg,
125
+ f"sdxl_{name}_freeu_device_{device}.png",
126
+ img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
127
+ mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
128
+ fx=2,
129
+ fy=2,
130
+ )
131
+
132
+
133
+ @pytest.mark.parametrize("device", ["cuda", "mps"])
134
+ @pytest.mark.parametrize(
135
+ "rect",
136
+ [
137
+ [-128, -128, 1024, 1024],
138
+ ],
139
+ )
140
+ def test_sdxl_outpainting(device, rect):
141
+ sd_steps = check_device(device)
142
+
143
+ model = ModelManager(
144
+ name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
145
+ device=torch.device(device),
146
+ disable_nsfw=True,
147
+ sd_cpu_textencoder=False,
148
+ )
149
+
150
+ cfg = get_config(
151
+ strategy=HDStrategy.ORIGINAL,
152
+ prompt="a dog sitting on a bench in the park",
153
+ sd_steps=sd_steps,
154
+ use_extender=True,
155
+ extender_x=rect[0],
156
+ extender_y=rect[1],
157
+ extender_width=rect[2],
158
+ extender_height=rect[3],
159
+ sd_strength=1.0,
160
+ sd_guidance_scale=8.0,
161
+ sd_sampler=SDSampler.ddim,
162
+ )
163
+
164
+ assert_equal(
165
+ model,
166
+ cfg,
167
+ f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}_device_{device}.png",
168
+ img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
169
+ mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
170
+ fx=1.5,
171
+ fy=1.5,
172
+ )
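Editor's note: these tests pull the diffusers/stable-diffusion-xl-1.0-inpainting-0.1 weights on first run and rely on check_device (defined in iopaint/tests/utils.py below) to skip themselves when CUDA or MPS is unavailable, so they are best run on a GPU machine with the model already cached; for example, pytest iopaint/tests/test_sdxl.py -k cuda -s selects only the CUDA-parametrized cases (the exact -k expression is just an illustration).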
iopaint/tests/utils.py ADDED
@@ -0,0 +1,77 @@
1
+ from pathlib import Path
2
+ import cv2
3
+ import pytest
4
+ import torch
5
+
6
+ from iopaint.helper import encode_pil_to_base64
7
+ from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
8
+ from PIL import Image
9
+
10
+ current_dir = Path(__file__).parent.absolute().resolve()
11
+ save_dir = current_dir / "result"
12
+ save_dir.mkdir(exist_ok=True, parents=True)
13
+
14
+
15
+ def check_device(device: str) -> int:
16
+ if device == "cuda" and not torch.cuda.is_available():
17
+ pytest.skip("CUDA is not available, skip test on cuda")
18
+ if device == "mps" and not torch.backends.mps.is_available():
19
+ pytest.skip("mps is not available, skip test on mps")
20
+ steps = 2 if device == "cpu" else 20
21
+ return steps
22
+
23
+
24
+ def assert_equal(
25
+ model,
26
+ config: InpaintRequest,
27
+ gt_name,
28
+ fx: float = 1,
29
+ fy: float = 1,
30
+ img_p=current_dir / "image.png",
31
+ mask_p=current_dir / "mask.png",
32
+ ):
33
+ img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
34
+ print(f"Input image shape: {img.shape}")
35
+ res = model(img, mask, config)
36
+ ok = cv2.imwrite(
37
+ str(save_dir / gt_name),
38
+ res,
39
+ [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
40
+ )
41
+ assert ok, save_dir / gt_name
42
+
43
+ """
44
+ Note that JPEG is lossy compression, so even at the highest quality setting (100),
45
+ the reloaded image's pixel values can differ from the original.
46
+ To save the image exactly as it is, use PNG or BMP instead.
47
+ """
48
+ # gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
49
+ # assert np.array_equal(res, gt)
50
+
51
+
52
+ def get_data(
53
+ fx: float = 1,
54
+ fy: float = 1.0,
55
+ img_p=current_dir / "image.png",
56
+ mask_p=current_dir / "mask.png",
57
+ ):
58
+ img = cv2.imread(str(img_p))
59
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
60
+ mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
61
+ img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
62
+ mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
63
+ return img, mask
64
+
65
+
66
+ def get_config(**kwargs):
67
+ data = dict(
68
+ sd_sampler=kwargs.get("sd_sampler", SDSampler.uni_pc),
69
+ ldm_steps=1,
70
+ ldm_sampler=LDMSampler.plms,
71
+ hd_strategy=kwargs.get("strategy", HDStrategy.ORIGINAL),
72
+ hd_strategy_crop_margin=32,
73
+ hd_strategy_crop_trigger_size=200,
74
+ hd_strategy_resize_limit=200,
75
+ )
76
+ data.update(**kwargs)
77
+ return InpaintRequest(image="", mask="", **data)
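Editor's note: taken together, these helpers resolve test assets relative to the test directory, build an InpaintRequest with sensible defaults, run the model, and write the result under result/. A minimal, hypothetical illustration of how they compose (not part of the diff; the field values are only examples, following how test_sdxl.py calls them):

cfg = get_config(
    strategy=HDStrategy.ORIGINAL,
    prompt="face of a fox, sitting on a bench",
    sd_steps=2,
)
img, mask = get_data(fx=1.3, fy=1.3)   # scale the bundled test image and mask
print(cfg.hd_strategy, img.shape, mask.shape)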
iopaint/web_config.py ADDED
@@ -0,0 +1,307 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from iopaint.schema import (
6
+ Device,
7
+ InteractiveSegModel,
8
+ RemoveBGModel,
9
+ RealESRGANModel,
10
+ ApiConfig,
11
+ )
12
+
13
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
14
+
15
+ from datetime import datetime
16
+ from json import JSONDecodeError
17
+
18
+ import gradio as gr
19
+ from iopaint.download import scan_models
20
+ from loguru import logger
21
+
22
+ from iopaint.const import *
23
+
24
+
25
+ _config_file: Path = None
26
+
27
+
28
+ default_configs = dict(
29
+ host="127.0.0.1",
30
+ port=8080,
31
+ inbrowser=True,
32
+ model=DEFAULT_MODEL,
33
+ model_dir=DEFAULT_MODEL_DIR,
34
+ no_half=False,
35
+ low_mem=False,
36
+ cpu_offload=False,
37
+ disable_nsfw_checker=False,
38
+ local_files_only=False,
39
+ cpu_textencoder=False,
40
+ device=Device.cuda,
41
+ input=None,
42
+ output_dir=None,
43
+ quality=95,
44
+ enable_interactive_seg=False,
45
+ interactive_seg_model=InteractiveSegModel.vit_b,
46
+ interactive_seg_device=Device.cpu,
47
+ enable_remove_bg=False,
48
+ remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
49
+ enable_anime_seg=False,
50
+ enable_realesrgan=False,
51
+ realesrgan_device=Device.cpu,
52
+ realesrgan_model=RealESRGANModel.realesr_general_x4v3,
53
+ enable_gfpgan=False,
54
+ gfpgan_device=Device.cpu,
55
+ enable_restoreformer=False,
56
+ restoreformer_device=Device.cpu,
57
+ )
58
+
59
+
60
+ class WebConfig(ApiConfig):
61
+ model_dir: str = DEFAULT_MODEL_DIR
62
+
63
+
64
+ def load_config(p: Path) -> WebConfig:
65
+ if p.exists():
66
+ with open(p, "r", encoding="utf-8") as f:
67
+ try:
68
+ return WebConfig(**{**default_configs, **json.load(f)})
69
+ except JSONDecodeError:
70
+ print("Failed to load config file, using default configs")
71
+ return WebConfig(**default_configs)
72
+ else:
73
+ return WebConfig(**default_configs)
74
+
75
+
76
+ def save_config(
77
+ host,
78
+ port,
79
+ model,
80
+ model_dir,
81
+ no_half,
82
+ low_mem,
83
+ cpu_offload,
84
+ disable_nsfw_checker,
85
+ local_files_only,
86
+ cpu_textencoder,
87
+ device,
88
+ input,
89
+ output_dir,
90
+ quality,
91
+ enable_interactive_seg,
92
+ interactive_seg_model,
93
+ interactive_seg_device,
94
+ enable_remove_bg,
95
+ remove_bg_model,
96
+ enable_anime_seg,
97
+ enable_realesrgan,
98
+ realesrgan_device,
99
+ realesrgan_model,
100
+ enable_gfpgan,
101
+ gfpgan_device,
102
+ enable_restoreformer,
103
+ restoreformer_device,
104
+ inbrowser,
105
+ ):
106
+ config = WebConfig(**locals())
107
+ if str(config.input) == ".":
108
+ config.input = None
109
+ if str(config.output_dir) == ".":
110
+ config.output_dir = None
111
+ config.model = config.model.strip()
112
+ print(config.model_dump_json(indent=4))
113
+ if config.input and not os.path.exists(config.input):
114
+ return "[Error] Input file or directory does not exist"
115
+
116
+ current_time = datetime.now().strftime("%H:%M:%S")
117
+ msg = f"[{current_time}] Successfully saved config to: {str(_config_file.absolute())}"
118
+ logger.info(msg)
119
+ try:
120
+ with open(_config_file, "w", encoding="utf-8") as f:
121
+ f.write(config.model_dump_json(indent=4))
122
+ except Exception as e:
123
+ return f"Failed to save config file: {str(e)}"
124
+ return msg
125
+
126
+
127
+ def change_current_model(new_model):
128
+ return new_model
129
+
130
+
131
+ def main(config_file: Path):
132
+ global _config_file
133
+ _config_file = config_file
134
+
135
+ init_config = load_config(config_file)
136
+ downloaded_models = [it.name for it in scan_models()]
137
+
138
+ with gr.Blocks() as demo:
139
+ with gr.Row():
140
+ with gr.Column():
141
+ gr.Textbox(config_file, label="Config file", interactive=False)
142
+ with gr.Column():
143
+ save_btn = gr.Button(value="Save configurations")
144
+ message = gr.HTML()
145
+
146
+ with gr.Tabs():
147
+ with gr.Tab("Common"):
148
+ with gr.Row():
149
+ host = gr.Textbox(init_config.host, label="Host")
150
+ port = gr.Number(init_config.port, label="Port", precision=0)
151
+ inbrowser = gr.Checkbox(init_config.inbrowser, label=INBROWSER_HELP)
152
+
153
+ with gr.Column():
154
+ model = gr.Textbox(
155
+ init_config.model,
156
+ label="Current Model. This is the model that will be used when the service starts. "
157
+ "If the model has not been downloaded before, it will be automatically downloaded. "
158
+ "You can select a model from the dropdown box below or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
159
+ )
160
+ with gr.Row():
161
+ recommend_model = gr.Dropdown(
162
+ ["lama", "mat", "migan"] + DIFFUSION_MODELS,
163
+ label="Recommended Models",
164
+ )
165
+ downloaded_model = gr.Dropdown(
166
+ downloaded_models, label="Downloaded Models"
167
+ )
168
+
169
+ device = gr.Radio(
170
+ Device.values(), label="Device", value=init_config.device
171
+ )
172
+ quality = gr.Slider(
173
+ value=init_config.quality,
174
+ label=f"Image Quality ({QUALITY_HELP})",
175
+ minimum=75,
176
+ maximum=100,
177
+ step=1,
178
+ )
179
+
180
+ no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
181
+ cpu_offload = gr.Checkbox(
182
+ init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
183
+ )
184
+ low_mem = gr.Checkbox(init_config.low_mem, label=f"{LOW_MEM_HELP}")
185
+ cpu_textencoder = gr.Checkbox(
186
+ init_config.cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
187
+ )
188
+ disable_nsfw_checker = gr.Checkbox(
189
+ init_config.disable_nsfw_checker, label=f"{DISABLE_NSFW_HELP}"
190
+ )
191
+ local_files_only = gr.Checkbox(
192
+ init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
193
+ )
194
+
195
+ with gr.Column():
196
+ model_dir = gr.Textbox(
197
+ init_config.model_dir, label=f"{MODEL_DIR_HELP}"
198
+ )
199
+ input = gr.Textbox(
200
+ init_config.input,
201
+ label=f"Input file or directory. {INPUT_HELP}",
202
+ )
203
+ output_dir = gr.Textbox(
204
+ init_config.output_dir,
205
+ label=f"Output directory. {OUTPUT_DIR_HELP}",
206
+ )
207
+
208
+ with gr.Tab("Plugins"):
209
+ with gr.Row():
210
+ enable_interactive_seg = gr.Checkbox(
211
+ init_config.enable_interactive_seg, label=INTERACTIVE_SEG_HELP
212
+ )
213
+ interactive_seg_model = gr.Radio(
214
+ InteractiveSegModel.values(),
215
+ label=f"Segment Anything models. {INTERACTIVE_SEG_MODEL_HELP}",
216
+ value=init_config.interactive_seg_model,
217
+ )
218
+ interactive_seg_device = gr.Radio(
219
+ Device.values(),
220
+ label="Segment Anything Device",
221
+ value=init_config.interactive_seg_device,
222
+ )
223
+ with gr.Row():
224
+ enable_remove_bg = gr.Checkbox(
225
+ init_config.enable_remove_bg, label=REMOVE_BG_HELP
226
+ )
227
+ remove_bg_model = gr.Radio(
228
+ RemoveBGModel.values(),
229
+ label="Remove bg model",
230
+ value=init_config.remove_bg_model,
231
+ )
232
+ with gr.Row():
233
+ enable_anime_seg = gr.Checkbox(
234
+ init_config.enable_anime_seg, label=ANIMESEG_HELP
235
+ )
236
+
237
+ with gr.Row():
238
+ enable_realesrgan = gr.Checkbox(
239
+ init_config.enable_realesrgan, label=REALESRGAN_HELP
240
+ )
241
+ realesrgan_device = gr.Radio(
242
+ Device.values(),
243
+ label="RealESRGAN Device",
244
+ value=init_config.realesrgan_device,
245
+ )
246
+ realesrgan_model = gr.Radio(
247
+ RealESRGANModel.values(),
248
+ label="RealESRGAN model",
249
+ value=init_config.realesrgan_model,
250
+ )
251
+ with gr.Row():
252
+ enable_gfpgan = gr.Checkbox(
253
+ init_config.enable_gfpgan, label=GFPGAN_HELP
254
+ )
255
+ gfpgan_device = gr.Radio(
256
+ Device.values(),
257
+ label="GFPGAN Device",
258
+ value=init_config.gfpgan_device,
259
+ )
260
+ with gr.Row():
261
+ enable_restoreformer = gr.Checkbox(
262
+ init_config.enable_restoreformer, label=RESTOREFORMER_HELP
263
+ )
264
+ restoreformer_device = gr.Radio(
265
+ Device.values(),
266
+ label="RestoreFormer Device",
267
+ value=init_config.restoreformer_device,
268
+ )
269
+
270
+ downloaded_model.change(change_current_model, [downloaded_model], model)
271
+ recommend_model.change(change_current_model, [recommend_model], model)
272
+
273
+ save_btn.click(
274
+ save_config,
275
+ [
276
+ host,
277
+ port,
278
+ model,
279
+ model_dir,
280
+ no_half,
281
+ low_mem,
282
+ cpu_offload,
283
+ disable_nsfw_checker,
284
+ local_files_only,
285
+ cpu_textencoder,
286
+ device,
287
+ input,
288
+ output_dir,
289
+ quality,
290
+ enable_interactive_seg,
291
+ interactive_seg_model,
292
+ interactive_seg_device,
293
+ enable_remove_bg,
294
+ remove_bg_model,
295
+ enable_anime_seg,
296
+ enable_realesrgan,
297
+ realesrgan_device,
298
+ realesrgan_model,
299
+ enable_gfpgan,
300
+ gfpgan_device,
301
+ enable_restoreformer,
302
+ restoreformer_device,
303
+ inbrowser,
304
+ ],
305
+ message,
306
+ )
307
+ demo.launch(inbrowser=True, show_api=False)
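Editor's note: main() builds a Gradio form around the saved JSON config, mirrors every WebConfig field into a widget, and writes the edited values back through save_config. A minimal, hypothetical way to launch it directly (not part of the diff; the config path is only an example and is normally supplied by the IOPaint CLI):

from pathlib import Path

from iopaint.web_config import main

main(Path("iopaint-config.json"))   # opens the Gradio editor for this (example) config file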
pretrained-model/version.txt ADDED
@@ -0,0 +1 @@
1
+ 1
pretrained-model/version_diffusers_cache.txt ADDED
@@ -0,0 +1 @@
1
+ 1
utils/tools.py ADDED
@@ -0,0 +1,505 @@
1
+ import os
2
+ import torch
3
+ import yaml
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def pil_loader(path):
11
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
12
+ with open(path, 'rb') as f:
13
+ img = Image.open(f)
14
+ return img.convert('RGB')
15
+
16
+
17
+ def default_loader(path):
18
+ return pil_loader(path)
19
+
20
+
21
+ def tensor_img_to_npimg(tensor_img):
22
+ """
23
+ Turn a tensor image with shape CxHxW to a numpy array image with shape HxWxC
24
+ :param tensor_img:
25
+ :return: a numpy array image with shape HxWxC
26
+ """
27
+ if not (torch.is_tensor(tensor_img) and tensor_img.ndimension() == 3):
28
+ raise NotImplementedError("Not supported tensor image. Only tensors with dimension CxHxW are supported.")
29
+ npimg = np.transpose(tensor_img.numpy(), (1, 2, 0))
30
+ npimg = npimg.squeeze()
31
+ assert isinstance(npimg, np.ndarray) and (npimg.ndim in {2, 3})
32
+ return npimg
33
+
34
+
35
+ # Change the values of tensor x from range [0, 1] to [-1, 1]
36
+ def normalize(x):
37
+ return x.mul_(2).add_(-1)
38
+
39
+ def same_padding(images, ksizes, strides, rates):
40
+ assert len(images.size()) == 4
41
+ batch_size, channel, rows, cols = images.size()
42
+ out_rows = (rows + strides[0] - 1) // strides[0]
43
+ out_cols = (cols + strides[1] - 1) // strides[1]
44
+ effective_k_row = (ksizes[0] - 1) * rates[0] + 1
45
+ effective_k_col = (ksizes[1] - 1) * rates[1] + 1
46
+ padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
47
+ padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
48
+ # Pad the input
49
+ padding_top = int(padding_rows / 2.)
50
+ padding_left = int(padding_cols / 2.)
51
+ padding_bottom = padding_rows - padding_top
52
+ padding_right = padding_cols - padding_left
53
+ paddings = (padding_left, padding_right, padding_top, padding_bottom)
54
+ images = torch.nn.ZeroPad2d(paddings)(images)
55
+ return images
56
+
57
+
58
+ def extract_image_patches(images, ksizes, strides, rates, padding='same'):
59
+ """
60
+ Extract patches from images and put them in the C output dimension.
61
+ :param padding:
62
+ :param images: A 4-D Tensor with shape [batch, channels, in_rows, in_cols]
63
+ :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
64
+ each dimension of images
65
+ :param strides: [stride_rows, stride_cols]
66
+ :param rates: [dilation_rows, dilation_cols]
67
+ :return: A Tensor
68
+ """
69
+ assert len(images.size()) == 4
70
+ assert padding in ['same', 'valid']
71
+ batch_size, channel, height, width = images.size()
72
+
73
+ if padding == 'same':
74
+ images = same_padding(images, ksizes, strides, rates)
75
+ elif padding == 'valid':
76
+ pass
77
+ else:
78
+ raise NotImplementedError('Unsupported padding type: {}.\
79
+ Only "same" or "valid" are supported.'.format(padding))
80
+
81
+ unfold = torch.nn.Unfold(kernel_size=ksizes,
82
+ dilation=rates,
83
+ padding=0,
84
+ stride=strides)
85
+ patches = unfold(images)
86
+ return patches # [N, C*k*k, L], L is the total number of such blocks
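Editor's note: extract_image_patches pads the input (for 'same' padding) and then unfolds it into sliding-window patches stacked along the channel dimension. A small, hypothetical shape check (not part of the diff):

import torch

images = torch.randn(2, 3, 32, 32)                       # [batch, channels, rows, cols]
patches = extract_image_patches(images, ksizes=[4, 4],
                                strides=[4, 4], rates=[1, 1], padding='same')
print(patches.shape)   # torch.Size([2, 48, 64]): C*k*k = 3*4*4, L = (32/4) * (32/4)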
87
+
88
+
89
+ def random_bbox(config, batch_size):
90
+ """Generate a random tlhw with configuration.
91
+
92
+ Args:
93
+ config: dict that provides 'image_shape', 'mask_shape', 'margin' and 'mask_batch_same'
94
+
95
+ Returns:
96
+ torch.Tensor: int64 tensor of shape (batch_size, 4) holding (top, left, height, width) boxes
97
+
98
+ """
99
+ img_height, img_width, _ = config['image_shape']
100
+ h, w = config['mask_shape']
101
+ margin_height, margin_width = config['margin']
102
+ maxt = img_height - margin_height - h
103
+ maxl = img_width - margin_width - w
104
+ bbox_list = []
105
+ if config['mask_batch_same']:
106
+ t = np.random.randint(margin_height, maxt)
107
+ l = np.random.randint(margin_width, maxl)
108
+ bbox_list.append((t, l, h, w))
109
+ bbox_list = bbox_list * batch_size
110
+ else:
111
+ for i in range(batch_size):
112
+ t = np.random.randint(margin_height, maxt)
113
+ l = np.random.randint(margin_width, maxl)
114
+ bbox_list.append((t, l, h, w))
115
+
116
+ return torch.tensor(bbox_list, dtype=torch.int64)
117
+
118
+
119
+ def test_random_bbox():
120
+ config = {'image_shape': [256, 256, 3], 'mask_shape': [128, 128],
121
+ 'margin': [0, 0], 'mask_batch_same': True}
122
+ # random_bbox expects a config dict and a batch size
123
+ bbox = random_bbox(config, batch_size=1)
124
+ return bbox
125
+
126
+
127
+ def bbox2mask(bboxes, height, width, max_delta_h, max_delta_w):
128
+ batch_size = bboxes.size(0)
129
+ mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32)
130
+ for i in range(batch_size):
131
+ bbox = bboxes[i]
132
+ delta_h = np.random.randint(max_delta_h // 2 + 1)
133
+ delta_w = np.random.randint(max_delta_w // 2 + 1)
134
+ mask[i, :, bbox[0] + delta_h:bbox[0] + bbox[2] - delta_h, bbox[1] + delta_w:bbox[1] + bbox[3] - delta_w] = 1.
135
+ return mask
136
+
137
+
138
+ def test_bbox2mask():
139
+ image_shape = [256, 256, 3]
140
+ config = {'image_shape': image_shape, 'mask_shape': [128, 128],
141
+ 'margin': [0, 0], 'mask_batch_same': True}
142
+ max_delta_shape = [32, 32]
143
+ bbox = random_bbox(config, batch_size=1)
144
+ mask = bbox2mask(bbox, image_shape[0], image_shape[1], max_delta_shape[0], max_delta_shape[1])
145
+ return mask
146
+
147
+
148
+ def local_patch(x, bbox_list):
149
+ assert len(x.size()) == 4
150
+ patches = []
151
+ for i, bbox in enumerate(bbox_list):
152
+ t, l, h, w = bbox
153
+ patches.append(x[i, :, t:t + h, l:l + w])
154
+ return torch.stack(patches, dim=0)
155
+
156
+
157
+ def mask_image(x, bboxes, config):
158
+ height, width, _ = config['image_shape']
159
+ max_delta_h, max_delta_w = config['max_delta_shape']
160
+ mask = bbox2mask(bboxes, height, width, max_delta_h, max_delta_w)
161
+ if x.is_cuda:
162
+ mask = mask.cuda()
163
+
164
+ if config['mask_type'] == 'hole':
165
+ result = x * (1. - mask)
166
+ elif config['mask_type'] == 'mosaic':
167
+ # TODO: Matching the mosaic patch size and the mask size
168
+ mosaic_unit_size = config['mosaic_unit_size']
169
+ downsampled_image = F.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest')
170
+ upsampled_image = F.interpolate(downsampled_image, size=(height, width), mode='nearest')
171
+ result = upsampled_image * mask + x * (1. - mask)
172
+ else:
173
+ raise NotImplementedError('Not implemented mask type.')
174
+
175
+ return result, mask
176
+
177
+
178
+ def spatial_discounting_mask(config):
179
+ """Generate spatial discounting mask constant.
180
+
181
+ Spatial discounting mask is first introduced in publication:
182
+ Generative Image Inpainting with Contextual Attention, Yu et al.
183
+
184
+ Args:
185
+ config: Config should have configuration including HEIGHT, WIDTH,
186
+ DISCOUNTED_MASK.
187
+
188
+ Returns:
189
+ torch.Tensor: spatial discounting mask
190
+
191
+ """
192
+ gamma = config['spatial_discounting_gamma']
193
+ height, width = config['mask_shape']
194
+ shape = [1, 1, height, width]
195
+ if config['discounted_mask']:
196
+ mask_values = np.ones((height, width))
197
+ for i in range(height):
198
+ for j in range(width):
199
+ mask_values[i, j] = max(
200
+ gamma ** min(i, height - i),
201
+ gamma ** min(j, width - j))
202
+ mask_values = np.expand_dims(mask_values, 0)
203
+ mask_values = np.expand_dims(mask_values, 0)
204
+ else:
205
+ mask_values = np.ones(shape)
206
+ spatial_discounting_mask_tensor = torch.tensor(mask_values, dtype=torch.float32)
207
+ if config['cuda']:
208
+ spatial_discounting_mask_tensor = spatial_discounting_mask_tensor.cuda()
209
+ return spatial_discounting_mask_tensor
210
+
211
+
212
+ def reduce_mean(x, axis=None, keepdim=False):
213
+ if not axis:
214
+ axis = range(len(x.shape))
215
+ for i in sorted(axis, reverse=True):
216
+ x = torch.mean(x, dim=i, keepdim=keepdim)
217
+ return x
218
+
219
+
220
+ def reduce_std(x, axis=None, keepdim=False):
221
+ if not axis:
222
+ axis = range(len(x.shape))
223
+ for i in sorted(axis, reverse=True):
224
+ x = torch.std(x, dim=i, keepdim=keepdim)
225
+ return x
226
+
227
+
228
+ def reduce_sum(x, axis=None, keepdim=False):
229
+ if not axis:
230
+ axis = range(len(x.shape))
231
+ for i in sorted(axis, reverse=True):
232
+ x = torch.sum(x, dim=i, keepdim=keepdim)
233
+ return x
234
+
235
+
236
+ def flow_to_image(flow):
237
+ """Transfer flow map to image.
238
+ Part of code forked from flownet.
239
+ """
240
+ out = []
241
+ maxu = -999.
242
+ maxv = -999.
243
+ minu = 999.
244
+ minv = 999.
245
+ maxrad = -1
246
+ for i in range(flow.shape[0]):
247
+ u = flow[i, :, :, 0]
248
+ v = flow[i, :, :, 1]
249
+ idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
250
+ u[idxunknow] = 0
251
+ v[idxunknow] = 0
252
+ maxu = max(maxu, np.max(u))
253
+ minu = min(minu, np.min(u))
254
+ maxv = max(maxv, np.max(v))
255
+ minv = min(minv, np.min(v))
256
+ rad = np.sqrt(u ** 2 + v ** 2)
257
+ maxrad = max(maxrad, np.max(rad))
258
+ u = u / (maxrad + np.finfo(float).eps)
259
+ v = v / (maxrad + np.finfo(float).eps)
260
+ img = compute_color(u, v)
261
+ out.append(img)
262
+ return np.float32(np.uint8(out))
263
+
264
+
265
+ def pt_flow_to_image(flow):
266
+ """Transfer flow map to image.
267
+ Part of code forked from flownet.
268
+ """
269
+ out = []
270
+ maxu = torch.tensor(-999)
271
+ maxv = torch.tensor(-999)
272
+ minu = torch.tensor(999)
273
+ minv = torch.tensor(999)
274
+ maxrad = torch.tensor(-1)
275
+ if torch.cuda.is_available():
276
+ maxu = maxu.cuda()
277
+ maxv = maxv.cuda()
278
+ minu = minu.cuda()
279
+ minv = minv.cuda()
280
+ maxrad = maxrad.cuda()
281
+ for i in range(flow.shape[0]):
282
+ u = flow[i, 0, :, :]
283
+ v = flow[i, 1, :, :]
284
+ idxunknow = (torch.abs(u) > 1e7) + (torch.abs(v) > 1e7)
285
+ u[idxunknow] = 0
286
+ v[idxunknow] = 0
287
+ maxu = torch.max(maxu, torch.max(u))
288
+ minu = torch.min(minu, torch.min(u))
289
+ maxv = torch.max(maxv, torch.max(v))
290
+ minv = torch.min(minv, torch.min(v))
291
+ rad = torch.sqrt((u ** 2 + v ** 2).float()).to(torch.int64)
292
+ maxrad = torch.max(maxrad, torch.max(rad))
293
+ u = u / (maxrad + torch.finfo(torch.float32).eps)
294
+ v = v / (maxrad + torch.finfo(torch.float32).eps)
295
+ # TODO: change the following to pytorch
296
+ img = pt_compute_color(u, v)
297
+ out.append(img)
298
+
299
+ return torch.stack(out, dim=0)
300
+
301
+
302
+ def highlight_flow(flow):
303
+ """Convert flow into middlebury color code image.
304
+ """
305
+ out = []
306
+ s = flow.shape
307
+ for i in range(flow.shape[0]):
308
+ img = np.ones((s[1], s[2], 3)) * 144.
309
+ u = flow[i, :, :, 0]
310
+ v = flow[i, :, :, 1]
311
+ for h in range(s[1]):
312
+ for w in range(s[2]):
313
+ ui = u[h, w]
314
+ vi = v[h, w]
315
+ img[ui, vi, :] = 255.
316
+ out.append(img)
317
+ return np.float32(np.uint8(out))
318
+
319
+
320
+ def pt_highlight_flow(flow):
321
+ """Convert flow into middlebury color code image.
322
+ """
323
+ out = []
324
+ s = flow.shape
325
+ for i in range(flow.shape[0]):
326
+ img = np.ones((s[1], s[2], 3)) * 144.
327
+ u = flow[i, :, :, 0]
328
+ v = flow[i, :, :, 1]
329
+ for h in range(s[1]):
330
+ for w in range(s[2]):
331
+ ui = u[h, w]
332
+ vi = v[h, w]
333
+ img[ui, vi, :] = 255.
334
+ out.append(img)
335
+ return np.float32(np.uint8(out))
336
+
337
+
338
+ def compute_color(u, v):
339
+ h, w = u.shape
340
+ img = np.zeros([h, w, 3])
341
+ nanIdx = np.isnan(u) | np.isnan(v)
342
+ u[nanIdx] = 0
343
+ v[nanIdx] = 0
344
+ # colorwheel = COLORWHEEL
345
+ colorwheel = make_color_wheel()
346
+ ncols = np.size(colorwheel, 0)
347
+ rad = np.sqrt(u ** 2 + v ** 2)
348
+ a = np.arctan2(-v, -u) / np.pi
349
+ fk = (a + 1) / 2 * (ncols - 1) + 1
350
+ k0 = np.floor(fk).astype(int)
351
+ k1 = k0 + 1
352
+ k1[k1 == ncols + 1] = 1
353
+ f = fk - k0
354
+ for i in range(np.size(colorwheel, 1)):
355
+ tmp = colorwheel[:, i]
356
+ col0 = tmp[k0 - 1] / 255
357
+ col1 = tmp[k1 - 1] / 255
358
+ col = (1 - f) * col0 + f * col1
359
+ idx = rad <= 1
360
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
361
+ notidx = np.logical_not(idx)
362
+ col[notidx] *= 0.75
363
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
364
+ return img
365
+
366
+
367
+ def pt_compute_color(u, v):
368
+ h, w = u.shape
369
+ img = torch.zeros([3, h, w])
370
+ if torch.cuda.is_available():
371
+ img = img.cuda()
372
+ nanIdx = (torch.isnan(u) + torch.isnan(v)) != 0
373
+ u[nanIdx] = 0.
374
+ v[nanIdx] = 0.
375
+ # colorwheel = COLORWHEEL
376
+ colorwheel = pt_make_color_wheel()
377
+ if torch.cuda.is_available():
378
+ colorwheel = colorwheel.cuda()
379
+ ncols = colorwheel.size()[0]
380
+ rad = torch.sqrt((u ** 2 + v ** 2).to(torch.float32))
381
+ a = torch.atan2(-v.to(torch.float32), -u.to(torch.float32)) / np.pi
382
+ fk = (a + 1) / 2 * (ncols - 1) + 1
383
+ k0 = torch.floor(fk).to(torch.int64)
384
+ k1 = k0 + 1
385
+ k1[k1 == ncols + 1] = 1
386
+ f = fk - k0.to(torch.float32)
387
+ for i in range(colorwheel.size()[1]):
388
+ tmp = colorwheel[:, i]
389
+ col0 = tmp[k0 - 1]
390
+ col1 = tmp[k1 - 1]
391
+ col = (1 - f) * col0 + f * col1
392
+ idx = rad <= 1. / 255.
393
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
394
+ notidx = (idx != 0)
395
+ col[notidx] *= 0.75
396
+ img[i, :, :] = col * (1 - nanIdx).to(torch.float32)
397
+ return img
398
+
399
+
400
+ def make_color_wheel():
401
+ RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
402
+ ncols = RY + YG + GC + CB + BM + MR
403
+ colorwheel = np.zeros([ncols, 3])
404
+ col = 0
405
+ # RY
406
+ colorwheel[0:RY, 0] = 255
407
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
408
+ col += RY
409
+ # YG
410
+ colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
411
+ colorwheel[col:col + YG, 1] = 255
412
+ col += YG
413
+ # GC
414
+ colorwheel[col:col + GC, 1] = 255
415
+ colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
416
+ col += GC
417
+ # CB
418
+ colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
419
+ colorwheel[col:col + CB, 2] = 255
420
+ col += CB
421
+ # BM
422
+ colorwheel[col:col + BM, 2] = 255
423
+ colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
424
+ col += BM
425
+ # MR
426
+ colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
427
+ colorwheel[col:col + MR, 0] = 255
428
+ return colorwheel
429
+
430
+
431
+ def pt_make_color_wheel():
432
+ RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
433
+ ncols = RY + YG + GC + CB + BM + MR
434
+ colorwheel = torch.zeros([ncols, 3])
435
+ col = 0
436
+ # RY
437
+ colorwheel[0:RY, 0] = 1.
438
+ colorwheel[0:RY, 1] = torch.arange(0, RY, dtype=torch.float32) / RY
439
+ col += RY
440
+ # YG
441
+ colorwheel[col:col + YG, 0] = 1. - (torch.arange(0, YG, dtype=torch.float32) / YG)
442
+ colorwheel[col:col + YG, 1] = 1.
443
+ col += YG
444
+ # GC
445
+ colorwheel[col:col + GC, 1] = 1.
446
+ colorwheel[col:col + GC, 2] = torch.arange(0, GC, dtype=torch.float32) / GC
447
+ col += GC
448
+ # CB
449
+ colorwheel[col:col + CB, 1] = 1. - (torch.arange(0, CB, dtype=torch.float32) / CB)
450
+ colorwheel[col:col + CB, 2] = 1.
451
+ col += CB
452
+ # BM
453
+ colorwheel[col:col + BM, 2] = 1.
454
+ colorwheel[col:col + BM, 0] = torch.arange(0, BM, dtype=torch.float32) / BM
455
+ col += BM
456
+ # MR
457
+ colorwheel[col:col + MR, 2] = 1. - (torch.arange(0, MR, dtype=torch.float32) / MR)
458
+ colorwheel[col:col + MR, 0] = 1.
459
+ return colorwheel
460
+
461
+
462
+ def is_image_file(filename):
463
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
464
+ filename_lower = filename.lower()
465
+ return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS)
466
+
467
+
468
+ def deprocess(img):
469
+ img = img.add_(1).div_(2)
470
+ return img
471
+
472
+
473
+ # get configs
474
+ def get_config(config):
475
+ with open(config, 'r') as stream:
476
+ return yaml.load(stream, Loader=yaml.Loader)
477
+
478
+
479
+ # Get model list for resume
480
+ def get_model_list(dirname, key, iteration=0):
481
+ if os.path.exists(dirname) is False:
482
+ return None
483
+ gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
484
+ os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
485
+ if not gen_models:
486
+ return None
487
+ gen_models.sort()
488
+ if iteration == 0:
489
+ last_model_name = gen_models[-1]
490
+ else:
491
+ for model_name in gen_models:
492
+ if '{:0>8d}'.format(iteration) in model_name:
493
+ return model_name
494
+ raise ValueError('No model found for this iteration')
495
+ return last_model_name
496
+
497
+
498
+ if __name__ == '__main__':
499
+ test_random_bbox()
500
+ mask = test_bbox2mask()
501
+ print(mask.shape)
502
+ import matplotlib.pyplot as plt
503
+
504
+ plt.imshow(mask[0, 0], cmap='gray')
505
+ plt.show()
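Editor's note: the bbox and mask helpers in this file compose into a simple free-form "hole" masking pipeline. A minimal, hypothetical end-to-end sketch (not part of the diff; the config keys follow how random_bbox, bbox2mask and mask_image read them):

import torch

config = {
    'image_shape': [256, 256, 3],
    'mask_shape': [128, 128],
    'margin': [0, 0],
    'mask_batch_same': True,
    'max_delta_shape': [32, 32],
    'mask_type': 'hole',
}
x = torch.rand(4, 3, 256, 256)                 # a batch of images in [0, 1]
bboxes = random_bbox(config, batch_size=4)     # (4, 4) int64 tensor of (top, left, h, w)
masked, mask = mask_image(x, bboxes, config)   # zero out the hole region of each image
print(masked.shape, mask.shape)                # (4, 3, 256, 256) and (4, 1, 256, 256)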